feat: ws connect (#3)

* chore: ws

* chore: build client stream

* feat: test ws connect

* ci: fix ci
This commit is contained in:
Nathan.fooo 2023-05-08 19:03:50 +08:00 committed by GitHub
parent 08847fad1d
commit 18e950a829
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 2144 additions and 1868 deletions

2
.gitignore vendored
View file

@ -9,3 +9,5 @@
package-lock.json
yarn.lock
node_modules
**/crates/AppFlowy-Collab/
data/

449
Cargo.lock generated
View file

@ -89,7 +89,7 @@ dependencies = [
"mime",
"percent-encoding",
"pin-project-lite",
"rand",
"rand 0.8.5",
"sha1",
"smallvec",
"tokio",
@ -189,7 +189,7 @@ dependencies = [
"anyhow",
"async-trait",
"derive_more",
"rand",
"rand 0.8.5",
"redis",
"serde",
"serde_json",
@ -370,7 +370,7 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [
"getrandom",
"getrandom 0.2.9",
"once_cell",
"version_check",
]
@ -382,7 +382,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
dependencies = [
"cfg-if",
"getrandom",
"getrandom 0.2.9",
"once_cell",
"version_check",
]
@ -446,6 +446,9 @@ dependencies = [
"bincode",
"bytes",
"chrono",
"collab-client-ws",
"collab-persistence",
"collab-sync",
"config",
"dashmap",
"derive_more",
@ -454,7 +457,7 @@ dependencies = [
"lazy_static",
"once_cell",
"openssl",
"rand",
"rand 0.8.5",
"rcgen",
"reqwest",
"secrecy",
@ -474,6 +477,7 @@ dependencies = [
"unicode-segmentation",
"uuid",
"validator",
"websocket",
]
[[package]]
@ -574,6 +578,12 @@ dependencies = [
"num-traits",
]
[[package]]
name = "atomic_refcell"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79d6dc922a2792b006573f60b2648076355daeae5ce9cb59507e5908c9625d31"
[[package]]
name = "autocfg"
version = "1.1.0"
@ -613,6 +623,26 @@ dependencies = [
"serde",
]
[[package]]
name = "bindgen"
version = "0.64.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4243e6031260db77ede97ad86c27e501d646a27ab57b59a574f725d98ab1fb4"
dependencies = [
"bitflags",
"cexpr",
"clang-sys",
"lazy_static",
"lazycell",
"peeking_take_while",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn 1.0.109",
]
[[package]]
name = "bit-set"
version = "0.5.3"
@ -700,6 +730,17 @@ dependencies = [
"bytes",
]
[[package]]
name = "bzip2-sys"
version = "0.1.11+1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc"
dependencies = [
"cc",
"libc",
"pkg-config",
]
[[package]]
name = "cc"
version = "1.0.79"
@ -709,6 +750,15 @@ dependencies = [
"jobserver",
]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]]
name = "cfg-if"
version = "1.0.0"
@ -738,6 +788,17 @@ dependencies = [
"inout",
]
[[package]]
name = "clang-sys"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c688fc74432808e3eb684cae8830a86be1d66a2bd58e1f248ed0960a590baf6f"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "codespan-reporting"
version = "0.11.1"
@ -748,6 +809,91 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "collab"
version = "0.1.0"
dependencies = [
"anyhow",
"bytes",
"lib0",
"parking_lot 0.12.1",
"serde",
"serde_json",
"thiserror",
"tracing",
"y-sync",
"yrs",
]
[[package]]
name = "collab-client-ws"
version = "0.1.0"
dependencies = [
"bytes",
"futures-util",
"serde",
"serde_json",
"thiserror",
"tokio",
"tokio-retry",
"tokio-stream",
"tokio-tungstenite",
"tracing",
]
[[package]]
name = "collab-persistence"
version = "0.1.0"
dependencies = [
"bincode",
"chrono",
"lazy_static",
"lib0",
"parking_lot 0.12.1",
"rocksdb",
"serde",
"sled",
"smallvec",
"thiserror",
"tokio",
"tracing",
"yrs",
]
[[package]]
name = "collab-plugins"
version = "0.1.0"
dependencies = [
"collab",
"collab-client-ws",
"collab-persistence",
"collab-sync",
"tracing",
"y-sync",
"yrs",
]
[[package]]
name = "collab-sync"
version = "0.1.0"
dependencies = [
"bytes",
"collab",
"futures-util",
"lib0",
"md5",
"parking_lot 0.12.1",
"serde",
"serde_json",
"thiserror",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"y-sync",
"yrs",
]
[[package]]
name = "combine"
version = "4.6.6"
@ -793,7 +939,7 @@ dependencies = [
"hkdf",
"hmac",
"percent-encoding",
"rand",
"rand 0.8.5",
"sha2",
"subtle",
"time",
@ -914,7 +1060,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"rand_core",
"rand_core 0.6.4",
"typenum",
]
@ -1308,6 +1454,19 @@ dependencies = [
"winapi",
]
[[package]]
name = "getrandom"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi 0.9.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
name = "getrandom"
version = "0.2.9"
@ -1316,7 +1475,7 @@ checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
dependencies = [
"cfg-if",
"libc",
"wasi",
"wasi 0.11.0+wasi-snapshot-preview1",
]
[[package]]
@ -1329,6 +1488,12 @@ dependencies = [
"polyval",
]
[[package]]
name = "glob"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]]
name = "h2"
version = "0.3.18"
@ -1635,12 +1800,65 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "lib0"
version = "0.16.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf23122cb1c970b77ea6030eac5e328669415b65d2ab245c99bfb110f9d62dc"
dependencies = [
"serde",
"serde_json",
"thiserror",
]
[[package]]
name = "libc"
version = "0.2.142"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317"
[[package]]
name = "libloading"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f"
dependencies = [
"cfg-if",
"winapi",
]
[[package]]
name = "librocksdb-sys"
version = "0.10.0+7.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fe4d5874f5ff2bc616e55e8c6086d478fcda13faf9495768a4aa1c22042d30b"
dependencies = [
"bindgen",
"bzip2-sys",
"cc",
"glob",
"libc",
"libz-sys",
"zstd-sys",
]
[[package]]
name = "libz-sys"
version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56ee889ecc9568871456d42f603d6a0ce59ff328d291063a45cbdf0036baf6db"
dependencies = [
"cc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "link-cplusplus"
version = "1.0.8"
@ -1723,6 +1941,12 @@ dependencies = [
"digest",
]
[[package]]
name = "md5"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]]
name = "memchr"
version = "2.5.0"
@ -1767,7 +1991,7 @@ checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9"
dependencies = [
"libc",
"log",
"wasi",
"wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.45.0",
]
@ -1975,7 +2199,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
dependencies = [
"base64ct",
"rand_core",
"rand_core 0.6.4",
"subtle",
]
@ -1991,6 +2215,12 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd"
[[package]]
name = "peeking_take_while"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]]
name = "pem"
version = "1.1.1"
@ -2096,6 +2326,19 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
dependencies = [
"getrandom 0.1.16",
"libc",
"rand_chacha 0.2.2",
"rand_core 0.5.1",
"rand_hc",
]
[[package]]
name = "rand"
version = "0.8.5"
@ -2103,8 +2346,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
"rand_chacha 0.3.1",
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
dependencies = [
"ppv-lite86",
"rand_core 0.5.1",
]
[[package]]
@ -2114,7 +2367,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
"rand_core 0.6.4",
]
[[package]]
name = "rand_core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
dependencies = [
"getrandom 0.1.16",
]
[[package]]
@ -2123,7 +2385,16 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
"getrandom 0.2.9",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
dependencies = [
"rand_core 0.5.1",
]
[[package]]
@ -2186,7 +2457,7 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
dependencies = [
"getrandom",
"getrandom 0.2.9",
"redox_syscall 0.2.16",
"thiserror",
]
@ -2264,17 +2535,6 @@ dependencies = [
"winreg",
]
[[package]]
name = "revdb"
version = "0.1.0"
dependencies = [
"bincode",
"serde",
"sled",
"tempfile",
"thiserror",
]
[[package]]
name = "ring"
version = "0.16.20"
@ -2290,6 +2550,22 @@ dependencies = [
"winapi",
]
[[package]]
name = "rocksdb"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "015439787fce1e75d55f279078d33ff14b4af5d93d995e8838ee4631301c8a99"
dependencies = [
"libc",
"librocksdb-sys",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc_version"
version = "0.4.0"
@ -2504,6 +2780,12 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "shlex"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3"
[[package]]
name = "signal-hook-registry"
version = "1.4.1"
@ -2538,6 +2820,15 @@ dependencies = [
"parking_lot 0.11.2",
]
[[package]]
name = "smallstr"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e922794d168678729ffc7e07182721a14219c65814e66e91b839a272fe5ae4f"
dependencies = [
"smallvec",
]
[[package]]
name = "smallvec"
version = "1.10.0"
@ -2621,7 +2912,7 @@ dependencies = [
"once_cell",
"paste",
"percent-encoding",
"rand",
"rand 0.8.5",
"rustls",
"rustls-pemfile",
"serde",
@ -2881,6 +3172,17 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-retry"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f"
dependencies = [
"pin-project",
"rand 0.8.5",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.23.4"
@ -2901,6 +3203,19 @@ dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]
name = "tokio-tungstenite"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
@ -3035,6 +3350,25 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
name = "tungstenite"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788"
dependencies = [
"base64 0.13.1",
"byteorder",
"bytes",
"http",
"httparse",
"log",
"rand 0.8.5",
"sha1",
"thiserror",
"url",
"utf-8",
]
[[package]]
name = "typenum"
version = "1.16.0"
@ -3113,13 +3447,19 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "uuid"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dad5567ad0cf5b760e5665964bec1b47dfd077ba8a2544b513f3556d3d239a2"
dependencies = [
"getrandom",
"getrandom 0.2.9",
"serde",
]
@ -3166,6 +3506,12 @@ dependencies = [
"try-lock",
]
[[package]]
name = "wasi"
version = "0.9.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
@ -3267,6 +3613,28 @@ dependencies = [
"webpki",
]
[[package]]
name = "websocket"
version = "0.1.0"
dependencies = [
"actix",
"actix-web-actors",
"bytes",
"collab",
"collab-persistence",
"collab-plugins",
"collab-sync",
"dashmap",
"futures-util",
"parking_lot 0.12.1",
"secrecy",
"serde",
"thiserror",
"tokio",
"tokio-stream",
"tracing",
]
[[package]]
name = "whoami"
version = "1.4.0"
@ -3492,6 +3860,17 @@ dependencies = [
"time",
]
[[package]]
name = "y-sync"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f54d34b68ec4514a0659838c2b1ba867c571b20b3804a1338dacf4fa9062d801"
dependencies = [
"lib0",
"thiserror",
"yrs",
]
[[package]]
name = "yaml-rust"
version = "0.4.5"
@ -3510,6 +3889,20 @@ dependencies = [
"time",
]
[[package]]
name = "yrs"
version = "0.16.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c2aef2bf89b4f7c003f9c73f1c8097427ca32e1d006443f3f607f11e79a797b"
dependencies = [
"atomic_refcell",
"lib0",
"rand 0.7.3",
"smallstr",
"smallvec",
"thiserror",
]
[[package]]
name = "zeroize"
version = "1.6.0"

View file

@ -51,6 +51,8 @@ bytes = "1.4.0"
bincode = "1.3.3"
dashmap = "5.4"
rcgen = { version = "0.10.0", features = ["pem", "x509-parser"] }
collab-sync = {version = "0.1.0"}
collab-persistence = {version = "0.1.0"}
# tracing
tracing = { version = "0.1.37" }
@ -63,9 +65,11 @@ sqlx = { version = "0.6", default-features = false, features = ["runtime-actix-r
#Local crate
token = { path = "./crates/token" }
snowflake = { path = "./crates/snowflake" }
websocket = { path = "./crates/websocket" }
[dev-dependencies]
once_cell = "1.7.2"
collab-client-ws = { version = "0.1.0" }
[[bin]]
name = "appflowy_server"
@ -78,6 +82,19 @@ path = "src/lib.rs"
[workspace]
members = [
"crates/token",
"crates/revdb",
"crates/snowflake",
"crates/websocket",
]
[patch.crates-io]
collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
collab-client-ws = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
collab-sync= { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
collab-persistence = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
collab-plugins = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
#collab = { path = "./crates/AppFlowy-Collab/collab" }
#collab-client-ws = { path = "./crates/AppFlowy-Collab/collab-client-ws" }
#collab-sync = { path = "./crates/AppFlowy-Collab/collab-sync" }
#collab-persistence = { path = "./crates/AppFlowy-Collab/collab-persistence" }
#collab-plugins = { path = "./crates/AppFlowy-Collab/collab-plugins"}

View file

@ -3,6 +3,7 @@ application:
host: 0.0.0.0
server_key: "Should-Use-The-Custom-Server-Key"
tls_config: "no_tls"
data_dir: "./data"
database:
host: "localhost"
port: 5432

View file

@ -1,15 +0,0 @@
[package]
name = "revdb"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
sled = "0.34.7"
thiserror = "1.0.30"
serde = { version = "1.0", features = ["derive"] }
bincode = "1.3.3"
[dev-dependencies]
tempfile = "3.4.0"

View file

@ -1,56 +0,0 @@
use crate::document::Document;
use crate::error::RevDBError;
use sled::{Batch, Db, IVec};
use std::path::Path;
pub struct RevDB {
pub(crate) db: Db,
}
impl RevDB {
pub fn open(path: impl AsRef<Path>) -> Result<Self, RevDBError> {
let db = sled::open(path)?;
Ok(Self { db })
}
pub fn document(&self) -> Document {
Document { db: self }
}
pub fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<IVec>, RevDBError> {
let value = self.db.get(key)?;
Ok(value)
}
pub fn batch_get<K: AsRef<[u8]>>(
&self,
from_key: K,
to_key: K,
) -> Result<Vec<IVec>, RevDBError> {
let iter = self.db.range(from_key..=to_key);
let mut items = vec![];
for item in iter {
let (_, value) = item?;
items.push(value)
}
Ok(items)
}
pub fn insert<K: AsRef<[u8]>>(&self, key: K, value: &[u8]) -> Result<(), RevDBError> {
let _ = self.db.insert(key, value)?;
Ok(())
}
pub fn batch_insert<'a, K: AsRef<[u8]>>(
&self,
items: impl IntoIterator<Item = (K, &'a [u8])>,
) -> Result<(), RevDBError> {
let mut batch = Batch::default();
let items = items.into_iter();
items.for_each(|(key, value)| {
batch.insert(key.as_ref(), value);
});
self.db.apply_batch(batch)?;
Ok(())
}
}

View file

@ -1,117 +0,0 @@
use crate::db::RevDB;
use crate::error::RevDBError;
use crate::range::RevRange;
use serde::{Deserialize, Serialize};
pub struct Document<'a> {
pub(crate) db: &'a RevDB,
}
impl<'a> Document<'a> {
pub fn insert(
&self,
uid: i64,
document_id: i64,
value: DocumentRevData,
) -> Result<(), RevDBError> {
let key = make_document_key(uid, document_id, value.rev_id);
self.db.insert(key, &value.to_vec()?)?;
Ok(())
}
pub fn get(
&self,
uid: i64,
document_id: i64,
rev_id: i64,
) -> Result<Option<DocumentRevData>, RevDBError> {
let key = make_document_key(uid, document_id, rev_id);
match self.db.get(key)? {
None => Ok(None),
Some(value) => {
let data = DocumentRevData::from_vec(value.as_ref())?;
Ok(Some(data))
}
}
}
pub fn get_with_range<R: Into<RevRange>>(
&self,
uid: i64,
document_id: i64,
range: R,
) -> Result<Vec<DocumentRevData>, RevDBError> {
let range = range.into();
let from = make_document_key(uid, document_id, range.start);
let to = make_document_key(uid, document_id, range.end);
self.batch_get(from, to)
}
pub fn get_after(
&self,
uid: i64,
document_id: i64,
rev_id: i64,
) -> Result<Vec<DocumentRevData>, RevDBError> {
let from = make_document_key(uid, document_id, rev_id);
let to = make_document_key(uid, document_id, i64::MAX);
self.batch_get(from, to)
}
pub fn get_before(
&self,
uid: i64,
document_id: i64,
rev_id: i64,
) -> Result<Vec<DocumentRevData>, RevDBError> {
let from = make_document_key(uid, document_id, 0);
let to = make_document_key(uid, document_id, rev_id);
self.batch_get(from, to)
}
fn batch_get<K: AsRef<[u8]>>(
&self,
from: K,
to: K,
) -> Result<Vec<DocumentRevData>, RevDBError> {
let items = self.db.batch_get(from, to)?;
let mut document_revs = vec![];
for item in items {
let rev_data = DocumentRevData::from_vec(item.as_ref())?;
document_revs.push(rev_data);
}
Ok(document_revs)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentRevData {
#[serde(rename = "rid")]
pub rev_id: i64,
#[serde(rename = "bid")]
pub base_rev_id: i64,
#[serde(rename = "data")]
pub content: String,
}
impl DocumentRevData {
pub fn from_vec(data: &[u8]) -> Result<Self, RevDBError> {
bincode::deserialize::<Self>(data).map_err(|_e| RevDBError::SerdeError)
}
pub fn to_vec(&self) -> Result<Vec<u8>, RevDBError> {
bincode::serialize(self).map_err(|_e| RevDBError::SerdeError)
}
}
// Optimize your data layout: Sled's B-Tree implementation works best when the keys are sequential,
// so try to organize the data in a way that maximizes sequential access.
fn make_document_key(uid: i64, document_id: i64, rev_id: i64) -> [u8; 24] {
let mut key = [0; 24];
key[0..8].copy_from_slice(&uid.to_be_bytes());
key[8..16].copy_from_slice(&document_id.to_be_bytes());
key[16..24].copy_from_slice(&rev_id.to_be_bytes());
key
}

View file

@ -1,11 +0,0 @@
#[derive(Debug, thiserror::Error)]
pub enum RevDBError {
#[error(transparent)]
Db(#[from] sled::Error),
#[error("Serde error")]
SerdeError,
#[error("invalid data")]
InvalidData,
}

View file

@ -1,4 +0,0 @@
pub mod db;
pub mod document;
pub mod error;
pub mod range;

View file

@ -1,48 +0,0 @@
use std::ops::{Range, RangeInclusive, RangeToInclusive};
#[derive(Clone)]
pub struct RevRange {
pub(crate) start: i64,
pub(crate) end: i64,
}
impl RevRange {
/// Construct a new `RevRange` representing the range [start..end).
/// It is an invariant that `start <= end`.
pub fn new(start: i64, end: i64) -> RevRange {
debug_assert!(start <= end);
RevRange { start, end }
}
}
impl From<RangeInclusive<i64>> for RevRange {
fn from(src: RangeInclusive<i64>) -> RevRange {
RevRange::new(*src.start(), src.end().saturating_add(1))
}
}
impl From<RangeToInclusive<i64>> for RevRange {
fn from(src: RangeToInclusive<i64>) -> RevRange {
RevRange::new(0, src.end.saturating_add(1))
}
}
impl From<Range<i64>> for RevRange {
fn from(src: Range<i64>) -> RevRange {
let Range { start, end } = src;
RevRange { start, end }
}
}
impl Iterator for RevRange {
type Item = i64;
fn next(&mut self) -> Option<i64> {
if self.start > self.end {
return None;
}
let val = self.start;
self.start += 1;
Some(val)
}
}

View file

@ -1,2 +0,0 @@
mod test;
mod util;

View file

@ -1,136 +0,0 @@
use crate::document::util::make_test_db;
use revdb::document::{Document, DocumentRevData};
use revdb::range::RevRange;
#[test]
fn insert_text() {
let db = make_test_db();
let document = db.document();
let uid = 12345678;
let document_id = 1;
let value = DocumentRevData {
rev_id: 0,
base_rev_id: 0,
content: "hello world".to_string(),
};
document.insert(uid, document_id, value.clone()).unwrap();
let restored_data = document.get(uid, document_id, 0).unwrap().unwrap();
assert_eq!(value.content, restored_data.content);
}
//noinspection RsExternalLinter
#[test]
fn insert_multi_text() {
let db = make_test_db();
let document = db.document();
let uid = 12345678;
let document_id = 1;
let mut base_rev_id = 0;
let mut expected_str = "".to_string();
for i in 0..=100 {
let content = i.to_string();
expected_str.push_str(&content);
let value = DocumentRevData {
rev_id: i,
base_rev_id,
content,
};
base_rev_id += 1;
document.insert(uid, document_id, value).unwrap();
}
//
let restored_str = document
.get_with_range(uid, document_id, RevRange::new(0, 100))
.unwrap()
.into_iter()
.map(|data| data.content)
.collect::<Vec<String>>()
.join("");
assert_eq!(expected_str, restored_str);
}
//noinspection RsExternalLinter
fn insert_100_string_to_document(uid: i64, document_id: i64, document: &Document) {
let mut base_rev_id = 0;
for i in 0..=100 {
let content = i.to_string();
let value = DocumentRevData {
rev_id: i,
base_rev_id,
content,
};
base_rev_id += 1;
document.insert(uid, document_id, value).unwrap();
}
}
fn values_to_string(values: Vec<DocumentRevData>) -> String {
values
.into_iter()
.map(|data| data.content)
.collect::<Vec<String>>()
.join("")
}
#[test]
fn get_value_before() {
let db = make_test_db();
let document = db.document();
let uid = 12345678;
let document_id = 1;
insert_100_string_to_document(uid, document_id, &document);
let restored_str = values_to_string(document.get_before(uid, document_id, 50).unwrap());
assert_eq!("01234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950", restored_str);
let restored_str = values_to_string(document.get_before(uid, document_id, 0).unwrap());
assert_eq!("0", restored_str);
}
#[test]
fn get_value_after() {
let db = make_test_db();
let document = db.document();
let uid = 12345678;
let document_id = 1;
insert_100_string_to_document(uid, document_id, &document);
let restored_str = values_to_string(document.get_after(uid, document_id, 50).unwrap());
assert_eq!("5051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100", restored_str);
let restored_str = values_to_string(document.get_after(uid, document_id, 100).unwrap());
assert_eq!("100", restored_str);
}
#[test]
fn get_value_with_range() {
let db = make_test_db();
let document = db.document();
let uid = 12345678;
let document_id = 1;
insert_100_string_to_document(uid, document_id, &document);
let restored_str = values_to_string(
document
.get_with_range(uid, document_id, RevRange::new(50, 60))
.unwrap(),
);
assert_eq!("5051525354555657585960", restored_str);
let restored_str = values_to_string(
document
.get_with_range(uid, document_id, RevRange::new(50, 200))
.unwrap(),
);
assert_eq!("5051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100", restored_str);
let restored_str = values_to_string(
document
.get_with_range(uid, document_id, RevRange::new(50, 50))
.unwrap(),
);
assert_eq!("50", restored_str);
}

View file

@ -1,7 +0,0 @@
use revdb::db::RevDB;
use tempfile::TempDir;
pub fn make_test_db() -> RevDB {
let tempdir = TempDir::new().unwrap();
RevDB::open(tempdir).unwrap()
}

View file

@ -1 +0,0 @@
mod document;

View file

@ -38,8 +38,7 @@ impl Snowflake {
}
self.last_timestamp = timestamp;
let id =
(timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence;
let id = (timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence;
id as i64
}

View file

@ -33,11 +33,13 @@ pub fn create_token(
data: impl Serialize,
expire_duration: Duration,
) -> Result<String, TokenError> {
Ok(TokenFields {
Ok(
TokenFields {
data,
expire_at: Utc::now() + expire_duration,
}
.sign_with_key(&generate_hmac_key(server_key))?)
.sign_with_key(&generate_hmac_key(server_key))?,
)
}
fn generate_hmac_key(server_key: &str) -> Hmac<Sha256> {

View file

@ -0,0 +1,25 @@
[package]
name = "websocket"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
actix = "0.13"
actix-web-actors = { version = "4.2.0" }
serde = { version = "1.0", features = ["derive"] }
thiserror = "1.0.30"
bytes = "1.0"
secrecy = { version = "0.8", features = ["serde"] }
parking_lot = "0.12.1"
tracing = "0.1.25"
futures-util = "0.3.26"
tokio-stream = { version = "0.1.14", features = ["sync"] }
tokio = { version = "1.26", features = ["sync"] }
dashmap = "5.4.0"
collab = { version = "0.1.0"}
collab-sync = { version = "0.1.0"}
collab-persistence = { version = "0.1.0"}
collab-plugins = { version = "0.1.0", features = ["disk_rocksdb"]}

View file

@ -0,0 +1,168 @@
use crate::entities::{ClientMessage, Connect, Disconnect, ServerMessage, WSUser};
use crate::error::WSError;
use crate::CollabServer;
use actix::{
fut, Actor, ActorContext, ActorFutureExt, Addr, AsyncContext, ContextFutureSpawner, Handler,
Recipient, Running, StreamHandler, WrapFuture,
};
use actix_web_actors::ws;
use bytes::Bytes;
use collab_sync::msg::CollabMessage;
use futures_util::Sink;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
pub struct CollabSession {
user: Arc<WSUser>,
hb: Instant,
pub server: Addr<CollabServer>,
}
impl CollabSession {
pub fn new(user: WSUser, server: Addr<CollabServer>) -> Self {
Self {
user: Arc::new(user),
hb: Instant::now(),
server,
}
}
fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
act.server.do_send(Disconnect {
user: act.user.clone(),
});
ctx.stop();
return;
}
ctx.ping(b"");
});
}
fn send_to_server(&self, bytes: Bytes) {
match CollabMessage::from_vec(bytes.to_vec()) {
Ok(collab_msg) => {
self.server.do_send(ClientMessage {
user: self.user.clone(),
collab_msg,
});
},
Err(e) => {
tracing::error!("Error parsing message: {:?}", e);
},
}
}
}
impl Actor for CollabSession {
type Context = ws::WebsocketContext<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
// start heartbeats otherwise server disconnects in 10 seconds
self.hb(ctx);
self
.server
.send(Connect {
socket: ctx.address().recipient(),
user: self.user.clone(),
})
.into_actor(self)
.then(|res, _session, ctx| {
match res {
Ok(Ok(_)) => {
tracing::trace!("Send connect message to server success")
},
_ => {
tracing::error!("Send connect message to server failed");
ctx.stop();
},
}
fut::ready(())
})
.wait(ctx);
}
fn stopping(&mut self, _: &mut Self::Context) -> Running {
self.server.do_send(Disconnect {
user: self.user.clone(),
});
Running::Stop
}
}
impl Handler<ServerMessage> for CollabSession {
type Result = ();
fn handle(&mut self, msg: ServerMessage, ctx: &mut Self::Context) {
ctx.binary(msg.collab_msg);
}
}
/// WebSocket message handler
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for CollabSession {
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
let msg = match msg {
Err(_) => {
ctx.stop();
return;
},
Ok(msg) => msg,
};
match msg {
ws::Message::Ping(msg) => {
self.hb = Instant::now();
ctx.pong(&msg);
},
ws::Message::Pong(_) => {
self.hb = Instant::now();
},
ws::Message::Text(_) => {},
ws::Message::Binary(bytes) => {
self.send_to_server(bytes);
},
ws::Message::Close(reason) => {
ctx.close(reason);
ctx.stop();
},
ws::Message::Continuation(_) => {
ctx.stop();
},
ws::Message::Nop => (),
}
}
}
/// A helper struct that wraps the [Recipient] type to implement the [Sink] trait
pub struct ClientSink(pub Recipient<ServerMessage>);
impl Sink<CollabMessage> for ClientSink {
type Error = WSError;
fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn start_send(self: Pin<&mut Self>, item: CollabMessage) -> Result<(), Self::Error> {
self.0.do_send(ServerMessage { collab_msg: item });
Ok(())
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
}

View file

@ -0,0 +1,56 @@
use crate::error::WSError;
use actix::{Message, Recipient};
use collab_sync::msg::CollabMessage;
use secrecy::{ExposeSecret, Secret};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct WSUser {
pub user_id: Secret<String>,
}
impl Hash for WSUser {
fn hash<H: Hasher>(&self, state: &mut H) {
let uid: &String = self.user_id.expose_secret();
uid.hash(state);
}
}
impl PartialEq<Self> for WSUser {
fn eq(&self, other: &Self) -> bool {
let uid: &String = self.user_id.expose_secret();
let other_uid: &String = other.user_id.expose_secret();
uid == other_uid
}
}
impl Eq for WSUser {}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), WSError>")]
pub struct Connect {
pub socket: Recipient<ServerMessage>,
pub user: Arc<WSUser>,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), WSError>")]
pub struct Disconnect {
pub user: Arc<WSUser>,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "()")]
pub struct ClientMessage {
pub user: Arc<WSUser>,
pub collab_msg: CollabMessage,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "()")]
pub struct ServerMessage {
pub collab_msg: CollabMessage,
}

View file

@ -0,0 +1,8 @@
#[derive(Debug, thiserror::Error)]
pub enum WSError {
#[error(transparent)]
Persistence(#[from] collab_persistence::error::PersistenceError),
#[error("Internal failure: {0}")]
Internal(#[from] Box<dyn std::error::Error + Send + Sync>),
}

View file

@ -0,0 +1,7 @@
mod client;
pub mod entities;
mod error;
mod server;
pub use client::*;
pub use server::*;

View file

@ -0,0 +1,190 @@
use crate::entities::{ClientMessage, Connect, Disconnect, WSUser};
use crate::error::WSError;
use crate::ClientSink;
use actix::{Actor, Context, Handler, ResponseFuture};
use collab::core::collab::MutexCollab;
use collab::core::origin::CollabOrigin;
use collab_persistence::kv::rocks_kv::RocksCollabDB;
use collab_persistence::kv::KVStore;
use collab_plugins::disk_plugin::rocksdb_server::RocksdbServerDiskPlugin;
use collab_sync::server::{
CollabBroadcast, CollabGroup, CollabIDGen, CollabId, NonZeroNodeId, COLLAB_ID_LEN,
};
use dashmap::DashMap;
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use collab_persistence::keys::make_collab_id_key;
use collab_sync::msg::CollabMessage;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use tokio_stream::wrappers::ReceiverStream;
#[derive(Clone)]
pub struct CollabServer {
db: Arc<RocksCollabDB>,
/// Generate collab_id for new collab object
collab_id_gen: Arc<Mutex<CollabIDGen>>,
/// Memory cache for fast lookup of collab_id from object_id
collab_id_by_object_id: Arc<DashMap<String, CollabId>>,
collab_groups: Arc<RwLock<HashMap<CollabId, CollabGroup>>>,
client_streams: Arc<RwLock<HashMap<Arc<WSUser>, ClientStream>>>,
}
impl CollabServer {
pub fn new(db: Arc<RocksCollabDB>) -> Result<Self, WSError> {
let collab_id_gen = Arc::new(Mutex::new(CollabIDGen::new(NonZeroNodeId(1))));
let collab_id_by_object_id = Arc::new(DashMap::new());
Ok(Self {
db,
collab_id_gen,
collab_id_by_object_id,
collab_groups: Default::default(),
client_streams: Default::default(),
})
}
fn create_collab_id(&self, object_id: &str) -> Result<CollabId, WSError> {
let collab_id = self.collab_id_gen.lock().next_id();
let collab_key = make_collab_id_key(object_id.as_ref());
self.db.with_write_txn(|w_txn| {
w_txn.insert(collab_key.as_ref(), collab_id.to_be_bytes())?;
Ok(())
})?;
Ok(collab_id)
}
fn get_collab_id(&self, object_id: &str) -> Option<CollabId> {
let collab_key = make_collab_id_key(object_id.as_ref());
let read_txn = self.db.read_txn();
let value = read_txn.get(collab_key.as_ref()).ok()??;
let mut bytes = [0; COLLAB_ID_LEN];
bytes[0..COLLAB_ID_LEN].copy_from_slice(value.as_ref());
Some(CollabId::from_be_bytes(bytes))
}
fn get_or_create_collab_id(&self, object_id: &str) -> Result<CollabId, WSError> {
let collab_id = self.get_collab_id(object_id);
if let Some(collab_id) = collab_id {
self.create_group_if_need(collab_id, object_id);
Ok(collab_id)
} else {
let collab_id = self.create_collab_id(object_id)?;
self
.collab_id_by_object_id
.insert(object_id.to_string(), collab_id);
self.create_group_if_need(collab_id, object_id);
Ok(collab_id)
}
}
fn create_group_if_need(&self, collab_id: CollabId, object_id: &str) {
if self.collab_groups.read().contains_key(&collab_id) {
return;
}
let collab = MutexCollab::new(CollabOrigin::Empty, object_id, vec![]);
let plugin = RocksdbServerDiskPlugin::new(collab_id, self.db.clone()).unwrap();
collab.lock().add_plugin(Arc::new(plugin));
collab.initial();
let broadcast = CollabBroadcast::new(object_id, collab.clone(), 10);
let group = CollabGroup {
collab,
broadcast,
subscribers: Default::default(),
};
self.collab_groups.write().insert(collab_id, group);
}
}
impl Actor for CollabServer {
type Context = Context<Self>;
}
impl Handler<Connect> for CollabServer {
type Result = Result<(), WSError>;
fn handle(&mut self, msg: Connect, _ctx: &mut Context<Self>) -> Self::Result {
let (stream_tx, rx) = tokio::sync::mpsc::channel(100);
let stream = ClientStream::new(ClientSink(msg.socket), ReceiverStream::new(rx), stream_tx);
self.client_streams.write().insert(msg.user, stream);
Ok(())
}
}
impl Handler<Disconnect> for CollabServer {
type Result = Result<(), WSError>;
fn handle(&mut self, msg: Disconnect, _: &mut Context<Self>) -> Self::Result {
self.client_streams.write().remove(&msg.user);
Ok(())
}
}
impl Handler<ClientMessage> for CollabServer {
type Result = ResponseFuture<()>;
fn handle(&mut self, msg: ClientMessage, _ctx: &mut Context<Self>) -> Self::Result {
let object_id = msg.collab_msg.object_id();
if let Ok(collab_id) = self.get_or_create_collab_id(object_id) {
if let Some(collab_group) = self.collab_groups.write().get_mut(&collab_id) {
if let Some(stream) = self.client_streams.write().get_mut(&msg.user) {
if let Some((sink, stream)) = stream.split() {
let origin = match msg.collab_msg.origin() {
None => CollabOrigin::Empty,
Some(client) => client.clone(),
};
let sub = collab_group
.broadcast
.subscribe(origin.clone(), sink, stream);
collab_group.subscribers.insert(origin, sub);
}
}
}
let client_streams = self.client_streams.clone();
Box::pin(async move {
if let Some(client_stream) = client_streams.read().get(&msg.user) {
let _ = client_stream.stream_tx.send(Ok(msg.collab_msg)).await;
}
})
} else {
Box::pin(async move {})
}
}
}
impl actix::Supervised for CollabServer {
fn restarting(&mut self, _ctx: &mut Context<CollabServer>) {
tracing::warn!("restarting");
}
}
pub struct ClientStream {
sink: Option<ClientSink>,
stream: Option<ReceiverStream<Result<CollabMessage, WSError>>>,
stream_tx: Sender<Result<CollabMessage, WSError>>,
}
impl ClientStream {
pub fn new(
sink: ClientSink,
stream: ReceiverStream<Result<CollabMessage, WSError>>,
stream_tx: Sender<Result<CollabMessage, WSError>>,
) -> Self {
Self {
sink: Some(sink),
stream: Some(stream),
stream_tx,
}
}
pub fn split(&mut self) -> Option<(ClientSink, ReceiverStream<Result<CollabMessage, WSError>>)> {
let sink = self.sink.take()?;
let stream = self.stream.take()?;
Some((sink, stream))
}
}

12
rustfmt.toml Normal file
View file

@ -0,0 +1,12 @@
# https://rust-lang.github.io/rustfmt/?version=master&search=
max_width = 100
tab_spaces = 2
newline_style = "Auto"
match_block_trailing_comma = true
use_field_init_shorthand = true
use_try_shorthand = true
reorder_imports = true
reorder_modules = true
remove_nested_parens = true
merge_derives = true
edition = "2021"

View file

@ -1,10 +1,12 @@
use crate::component::auth::LoggedUser;
use crate::component::ws::{MessageReceivers, WSClient, WSServer};
use crate::state::State;
use actix::Addr;
use actix_web::web::{Data, Path, Payload};
use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope};
use actix_web_actors::ws;
use secrecy::Secret;
use websocket::entities::WSUser;
use websocket::{CollabServer, CollabSession};
pub fn ws_scope() -> Scope {
web::scope("/ws").service(establish_ws_connection)
@ -16,17 +18,23 @@ pub async fn establish_ws_connection(
payload: Payload,
token: Path<String>,
state: Data<State>,
server: Data<Addr<WSServer>>,
msg_receivers: Data<MessageReceivers>,
server: Data<Addr<CollabServer>>,
) -> Result<HttpResponse> {
tracing::info!("establish_ws_connection");
let user = LoggedUser::from_token(&state.config.application.server_key, token.as_str())?;
let client = WSClient::new(user, server.get_ref().clone(), msg_receivers);
let client = CollabSession::new(user.into(), server.get_ref().clone());
match ws::start(client, &request, payload) {
Ok(response) => Ok(response),
Err(e) => {
tracing::error!("ws connection error: {:?}", e);
Err(e)
},
}
}
impl From<LoggedUser> for WSUser {
fn from(user: LoggedUser) -> Self {
Self {
user_id: Secret::new(user.expose_secret().to_string()),
}
}
}

View file

@ -10,6 +10,9 @@ use actix_session::SessionMiddleware;
use actix_web::cookie::Key;
use actix_web::{dev::Server, web, web::Data, App, HttpServer};
use actix::Actor;
use collab_persistence::kv::rocks_kv::RocksCollabDB;
use openssl::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod};
use openssl::x509::X509;
use secrecy::{ExposeSecret, Secret};
@ -18,7 +21,9 @@ use sqlx::{postgres::PgPoolOptions, PgPool};
use std::net::TcpListener;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing_actix_web::TracingLogger;
use websocket::CollabServer;
pub struct Application {
port: u16,
@ -63,20 +68,23 @@ pub async fn run(
.as_ref()
.map(|(_, server_key)| Key::from(server_key.expose_secret().as_bytes()))
.unwrap_or_else(Key::generate);
let collab_server = CollabServer::new(state.rocksdb.clone()).unwrap().start();
let mut server = HttpServer::new(move || {
App::new()
// Session middleware
.wrap(
SessionMiddleware::builder(redis_store.clone(), key.clone())
.cookie_name(HEADER_TOKEN.to_string())
.build(),
)
// .wrap(ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, add_error_header))
.wrap(IdentityMiddleware::default())
.wrap(default_cors())
.wrap(TracingLogger::default())
.app_data(web::JsonConfig::default().limit(4096))
.service(user_scope())
.service(ws_scope())
.app_data(Data::new(collab_server.clone()))
.app_data(Data::new(state.clone()))
});
@ -84,7 +92,7 @@ pub async fn run(
None => server.listen(listener)?,
Some((certificate, _)) => {
server.listen_openssl(listener, make_ssl_acceptor_builder(certificate))?
}
},
};
Ok(server.run())
@ -103,8 +111,11 @@ pub async fn init_state(config: &Config) -> State {
.await
.unwrap_or_else(|_| panic!("Failed to connect to Postgres at {:?}.", config.database));
std::fs::create_dir_all(config.application.rocksdb_db_dir()).expect("create rocksdb db dir");
let rocksdb = Arc::new(RocksCollabDB::open(config.application.rocksdb_db_dir()).unwrap());
State {
pg_pool,
rocksdb,
config: Arc::new(config.clone()),
user: Arc::new(Default::default()),
id_gen: Arc::new(RwLock::new(Snowflake::new(1))),
@ -136,3 +147,10 @@ fn make_ssl_acceptor_builder(certificate: Secret<String>) -> SslAcceptorBuilder
.unwrap();
builder
}
// fn add_error_header<B>(
// res: dev::ServiceResponse<B>,
// ) -> Result<ErrorHandlerResponse<B>, actix_web::Error> {
// tracing::error!("{:?}", res.request());
// Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
// }

View file

@ -38,7 +38,8 @@ pub async fn validate_credentials(
.await
.context("Failed to spawn blocking task.")??;
uid.ok_or_else(|| anyhow::anyhow!("Unknown email."))
uid
.ok_or_else(|| anyhow::anyhow!("Unknown email."))
.map_err(AuthError::InvalidCredentials)
}

View file

@ -40,7 +40,7 @@ pub async fn login(
},
Secret::new(token),
))
}
},
Err(err) => Err(err),
}
}
@ -210,13 +210,13 @@ pub struct ChangePasswordRequest {
}
#[derive(Clone, Default)]
pub struct WrapI64(i64);
impl Copy for WrapI64 {}
impl DefaultIsZeroes for WrapI64 {}
impl DebugSecret for WrapI64 {}
impl CloneableSecret for WrapI64 {}
pub struct SecretI64(i64);
impl Copy for SecretI64 {}
impl DefaultIsZeroes for SecretI64 {}
impl DebugSecret for SecretI64 {}
impl CloneableSecret for SecretI64 {}
impl std::ops::Deref for WrapI64 {
impl std::ops::Deref for SecretI64 {
type Target = i64;
fn deref(&self) -> &Self::Target {
@ -225,17 +225,17 @@ impl std::ops::Deref for WrapI64 {
}
#[derive(Debug, Clone)]
pub struct LoggedUser(Secret<WrapI64>);
pub struct LoggedUser(Secret<SecretI64>);
impl From<Claim> for LoggedUser {
fn from(c: Claim) -> Self {
Self(Secret::new(WrapI64(c.uid)))
Self(Secret::new(SecretI64(c.uid)))
}
}
impl LoggedUser {
pub fn new(uid: i64) -> Self {
Self(Secret::new(WrapI64(uid)))
Self(Secret::new(SecretI64(uid)))
}
pub fn from_token(server_key: &Secret<String>, token: &str) -> Result<Self, AuthError> {

View file

@ -1,3 +1,2 @@
pub mod auth;
pub mod token_state;
pub mod ws;

View file

@ -1,164 +0,0 @@
use crate::component::auth::LoggedUser;
use crate::component::ws::entities::{
Connect, Disconnect, MessageDetail, MessagePayload, Socket, WebSocketMessage,
};
use crate::component::ws::server::WSServer;
use crate::component::ws::{HEARTBEAT_INTERVAL, PING_TIMEOUT};
use actix::*;
use actix_http::ws::Message::*;
use actix_web::web::Data;
use actix_web_actors::ws;
use bytes::Bytes;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
pub trait MessageReceiver: Send + Sync {
fn receive(&self, data: WSClientData);
}
#[derive(Default)]
pub struct MessageReceivers {
inner: HashMap<u8, Arc<dyn MessageReceiver>>,
}
impl MessageReceivers {
pub fn new() -> Self {
MessageReceivers::default()
}
pub fn insert(&mut self, channel: u8, receiver: Arc<dyn MessageReceiver>) {
self.inner.insert(channel, receiver);
}
pub fn get(&self, source: u8) -> Option<&Arc<dyn MessageReceiver>> {
self.inner.get(&source)
}
}
#[allow(dead_code)]
pub struct WSClientData {
pub(crate) socket: Socket,
pub(crate) detail: MessageDetail,
}
pub struct WSClient {
user: Arc<LoggedUser>,
server: Addr<WSServer>,
msg_receivers: Data<MessageReceivers>,
hb: Instant,
}
impl WSClient {
pub fn new(
user: LoggedUser,
server: Addr<WSServer>,
msg_receivers: Data<MessageReceivers>,
) -> Self {
Self {
user: Arc::new(user),
server,
msg_receivers,
hb: Instant::now(),
}
}
fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
ctx.run_interval(HEARTBEAT_INTERVAL, |client, ctx| {
if Instant::now().duration_since(client.hb) > PING_TIMEOUT {
client.server.do_send(Disconnect {
user: client.user.clone(),
});
ctx.stop();
} else {
ctx.ping(b"");
}
});
}
fn handle_binary_message(&self, bytes: Bytes, socket: Socket) {
let MessagePayload { channel, detail } = MessagePayload::from_bytes(&bytes);
match self.msg_receivers.get(channel) {
None => {
tracing::error!("Can't find the receiver for {:?}", channel);
}
Some(handler) => {
let client_data = WSClientData { socket, detail };
handler.receive(client_data);
}
}
}
}
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WSClient {
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
match msg {
Ok(Ping(msg)) => {
self.hb = Instant::now();
ctx.pong(&msg);
}
Ok(Pong(_msg)) => {
// tracing::trace!("Receive {} pong {:?}", &self.session_id, &msg);
self.hb = Instant::now();
}
Ok(Binary(bytes)) => {
let socket = ctx.address().recipient();
self.handle_binary_message(bytes, socket);
}
Ok(Text(_)) => {
tracing::warn!("Receive unexpected text message");
}
Ok(Close(reason)) => {
ctx.close(reason);
ctx.stop();
}
Ok(ws::Message::Continuation(_)) => {}
Ok(ws::Message::Nop) => {}
Err(e) => {
tracing::error!("WebSocketStream protocol error {:?}", e);
ctx.stop();
}
}
}
}
impl Handler<WebSocketMessage> for WSClient {
type Result = ();
fn handle(&mut self, msg: WebSocketMessage, ctx: &mut Self::Context) {
ctx.binary(msg.0);
}
}
impl Actor for WSClient {
type Context = ws::WebsocketContext<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
self.hb(ctx);
let socket = ctx.address().recipient();
let connect = Connect {
socket,
user: self.user.clone(),
};
self.server
.send(connect)
.into_actor(self)
.then(|res, _client, _ctx| {
match res {
Ok(Ok(_)) => tracing::trace!("Send connect message to server success"),
Ok(Err(e)) => tracing::error!("Send connect message to server failed: {:?}", e),
Err(e) => tracing::error!("Send connect message to server failed: {:?}", e),
}
fut::ready(())
})
.wait(ctx);
}
fn stopping(&mut self, _: &mut Self::Context) -> Running {
self.server.do_send(Disconnect {
user: self.user.clone(),
});
Running::Stop
}
}

View file

@ -1,92 +0,0 @@
use crate::component::auth::LoggedUser;
use actix::{Message, Recipient};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::fmt::Formatter;
use std::sync::Arc;
pub type Socket = Recipient<WebSocketMessage>;
#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)]
pub struct WSSessionId(pub String);
impl<T: AsRef<str>> std::convert::From<T> for WSSessionId {
fn from(s: T) -> Self {
WSSessionId(s.as_ref().to_owned())
}
}
impl std::fmt::Display for WSSessionId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let desc = &self.0.to_string();
f.write_str(desc)
}
}
pub struct Session {
pub user: Arc<LoggedUser>,
pub socket: Socket,
}
impl std::convert::From<Connect> for Session {
fn from(c: Connect) -> Self {
Self {
user: c.user,
socket: c.socket,
}
}
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), WSError>")]
pub struct Connect {
pub socket: Socket,
pub user: Arc<LoggedUser>,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), WSError>")]
pub struct Disconnect {
pub user: Arc<LoggedUser>,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "()")]
pub struct WebSocketMessage(pub Bytes);
impl std::ops::Deref for WebSocketMessage {
type Target = Bytes;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct MessagePayload {
pub(crate) channel: u8,
pub(crate) detail: MessageDetail,
}
impl MessagePayload {
pub fn from_bytes<T: AsRef<[u8]>>(bytes: T) -> Self {
serde_json::from_slice(bytes.as_ref()).unwrap()
}
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum MessageDetail {
Document(MessageContent),
Database(MessageContent),
}
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageContent {
content: String,
}
#[derive(Debug)]
pub enum WSError {
Internal,
}

View file

@ -1,11 +0,0 @@
use std::time::Duration;
mod client;
mod entities;
mod server;
pub use client::*;
pub use server::WSServer;
pub(crate) const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(8);
pub(crate) const PING_TIMEOUT: Duration = Duration::from_secs(60);

View file

@ -1,44 +0,0 @@
use crate::component::ws::entities::{Connect, Disconnect, WSError, WebSocketMessage};
use actix::{Actor, Context, Handler};
#[derive(Default)]
pub struct WSServer {}
impl WSServer {
pub fn new() -> Self {
WSServer::default()
}
pub fn send(&self, _msg: WebSocketMessage) {}
}
impl Actor for WSServer {
type Context = Context<Self>;
fn started(&mut self, _ctx: &mut Self::Context) {}
}
impl Handler<Connect> for WSServer {
type Result = Result<(), WSError>;
fn handle(&mut self, _msg: Connect, _ctx: &mut Context<Self>) -> Self::Result {
Ok(())
}
}
impl Handler<Disconnect> for WSServer {
type Result = Result<(), WSError>;
fn handle(&mut self, _msg: Disconnect, _: &mut Context<Self>) -> Self::Result {
Ok(())
}
}
impl Handler<WebSocketMessage> for WSServer {
type Result = ();
fn handle(&mut self, _msg: WebSocketMessage, _ctx: &mut Context<Self>) -> Self::Result {}
}
impl actix::Supervised for WSServer {
fn restarting(&mut self, _ctx: &mut Context<WSServer>) {
tracing::warn!("restarting");
}
}

View file

@ -3,6 +3,7 @@ use secrecy::Secret;
use serde_aux::field_attributes::deserialize_number_from_string;
use sqlx::postgres::{PgConnectOptions, PgSslMode};
use std::convert::{TryFrom, TryInto};
use std::path::PathBuf;
#[derive(serde::Deserialize, Clone, Debug)]
pub struct Config {
@ -24,6 +25,7 @@ pub struct ApplicationSettings {
#[serde(deserialize_with = "deserialize_number_from_string")]
pub port: u16,
pub host: String,
pub data_dir: PathBuf,
pub server_key: Secret<String>,
pub tls_config: Option<TlsConfig>,
}
@ -38,6 +40,10 @@ impl ApplicationSettings {
},
}
}
pub fn rocksdb_db_dir(&self) -> PathBuf {
self.data_dir.join("rocksdb")
}
}
#[derive(serde::Deserialize, Clone, Debug)]

View file

@ -58,6 +58,6 @@ pub fn validate_password(password: &str) -> bool {
Err(e) => {
tracing::error!("validate_password fail: {:?}", e);
false
}
},
}
}

View file

@ -4,7 +4,11 @@ use appflowy_server::telemetry::{get_subscriber, init_subscriber};
#[actix_web::main]
async fn main() -> anyhow::Result<()> {
let subscriber = get_subscriber("appflowy_server".into(), "info".into(), std::io::stdout);
let subscriber = get_subscriber(
"appflowy_server".to_string(),
"info".to_string(),
std::io::stdout,
);
init_subscriber(subscriber);
let configuration = get_configuration().expect("Failed to read configuration.");

View file

@ -7,7 +7,7 @@ use actix_web::http;
// Cors short for Cross-Origin Resource Sharing.
pub fn default_cors() -> Cors {
Cors::default() // allowed_origin return access-control-allow-origin: * by default
// .allowed_origin("http://127.0.0.1:8080")
.allow_any_origin()
.send_wildcard()
.allowed_methods(vec!["GET", "POST", "PUT", "DELETE"])
.allowed_headers(vec![http::header::ACCEPT])

View file

@ -1,6 +1,8 @@
use crate::component::auth::LoggedUser;
use crate::config::config::Config;
use chrono::{DateTime, Utc};
use collab_persistence::kv::rocks_kv::RocksCollabDB;
use snowflake::Snowflake;
use sqlx::PgPool;
use std::collections::BTreeMap;
@ -10,6 +12,7 @@ use tokio::sync::RwLock;
#[derive(Clone)]
pub struct State {
pub pg_pool: PgPool,
pub rocksdb: Arc<RocksCollabDB>,
pub config: Arc<Config>,
pub user: Arc<RwLock<UserCache>>,
pub id_gen: Arc<RwLock<Snowflake>>,
@ -49,17 +52,17 @@ impl UserCache {
None => {
tracing::debug!("user not login yet or server was reboot");
false
}
},
Some(status) => match *status {
AuthStatus::Authorized(last_time) => {
let current_time = Utc::now();
let days = (current_time - last_time).num_days();
days < EXPIRED_DURATION_DAYS
}
},
AuthStatus::NotAuthorized => {
tracing::debug!("user logout already");
false
}
},
},
}
}
@ -72,7 +75,8 @@ impl UserCache {
}
pub fn unauthorized(&mut self, user: LoggedUser) {
self.user
self
.user
.insert(user.expose_secret().to_owned(), AuthStatus::NotAuthorized);
}
}

View file

@ -4,7 +4,7 @@ use tracing::Subscriber;
use tracing_bunyan_formatter::{BunyanFormattingLayer, JsonStorageLayer};
use tracing_log::LogTracer;
use tracing_subscriber::fmt::MakeWriter;
use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry};
use tracing_subscriber::{layer::SubscriberExt, EnvFilter};
/// Compose multiple layers into a `tracing`'s subscriber.
pub fn get_subscriber<Sink>(
@ -15,10 +15,12 @@ pub fn get_subscriber<Sink>(
where
Sink: for<'a> MakeWriter<'a> + Send + Sync + 'static,
{
let env_filter =
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(env_filter));
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(env_filter));
// let env_filter = EnvFilter::new(env_filter);
let formatting_layer = BunyanFormattingLayer::new(name, sink);
Registry::default()
tracing_subscriber::fmt()
.with_ansi(true)
.finish()
.with(env_filter)
.with(JsonStorageLayer)
.with(formatting_layer)

View file

@ -1,8 +1,8 @@
use crate::test_server::{spawn_server, TestUser};
use crate::util::{spawn_server, TestUser};
use actix_web::http::StatusCode;
use appflowy_server::component::auth::LoginResponse;
#[tokio::test]
#[actix_rt::test]
async fn login_success() {
let server = spawn_server().await;
let test_user = TestUser::generate();
@ -16,7 +16,7 @@ async fn login_success() {
assert!(!response.token.is_empty())
}
#[tokio::test]
#[actix_rt::test]
async fn login_with_empty_email() {
let server = spawn_server().await;
let test_user = TestUser::generate();
@ -26,7 +26,7 @@ async fn login_with_empty_email() {
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
#[actix_rt::test]
async fn login_with_empty_password() {
let server = spawn_server().await;
let test_user = TestUser::generate();
@ -36,7 +36,7 @@ async fn login_with_empty_password() {
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
#[actix_rt::test]
async fn login_with_unknown_user() {
let server = spawn_server().await;
let http_resp = server.login("unknown@appflowy.io", "Abc@123!").await;

View file

@ -1,4 +1,4 @@
mod login;
mod password;
mod register;
mod test_server;
mod ws;

View file

@ -1,7 +1,7 @@
use crate::test_server::{spawn_server, TestUser};
use crate::util::{spawn_server, TestUser};
use actix_web::http::StatusCode;
#[tokio::test]
#[actix_rt::test]
async fn change_password_with_unmatched_password() {
let server = spawn_server().await;
let test_user = TestUser::generate();
@ -20,7 +20,7 @@ async fn change_password_with_unmatched_password() {
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
#[actix_rt::test]
async fn login_fail_after_change_password() {
let server = spawn_server().await;
let test_user = TestUser::generate();
@ -36,7 +36,7 @@ async fn login_fail_after_change_password() {
assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
#[actix_rt::test]
async fn login_success_with_new_password() {
let server = spawn_server().await;
let test_user = TestUser::generate();

View file

@ -1,8 +1,8 @@
use crate::test_server::{error_msg_from_resp, spawn_server};
use crate::util::{error_msg_from_resp, spawn_server};
use appflowy_server::component::auth::{InputParamsError, RegisterResponse};
use reqwest::StatusCode;
#[tokio::test]
#[actix_rt::test]
// curl -X POST --url http://0.0.0.0:8000/api/user/register --header 'content-type: application/json' --data '{"name":"fake name", "email":"fake@appflowy.io", "password":"Fake@123"}'
async fn register_success() {
let server = spawn_server().await;
@ -15,7 +15,7 @@ async fn register_success() {
println!("{:?}", response);
}
#[tokio::test]
#[actix_rt::test]
async fn register_with_invalid_password() {
let server = spawn_server().await;
let http_resp = server.register("user 1", "fake@appflowy.io", "123").await;
@ -26,7 +26,7 @@ async fn register_with_invalid_password() {
);
}
#[tokio::test]
#[actix_rt::test]
async fn register_with_invalid_name() {
let server = spawn_server().await;
let name = "".to_string();
@ -40,7 +40,7 @@ async fn register_with_invalid_name() {
);
}
#[tokio::test]
#[actix_rt::test]
async fn register_with_invalid_email() {
let server = spawn_server().await;
let email = "appflowy.io".to_string();

View file

@ -1,189 +0,0 @@
use appflowy_server::application::{init_state, Application};
use appflowy_server::config::config::{get_configuration, DatabaseSetting, TlsConfig};
use appflowy_server::state::State;
use appflowy_server::telemetry::{get_subscriber, init_subscriber};
use once_cell::sync::Lazy;
use reqwest::Certificate;
use appflowy_server::component::auth::{RegisterResponse, HEADER_TOKEN};
use sqlx::types::Uuid;
use sqlx::{Connection, Executor, PgConnection, PgPool};
// Ensure that the `tracing` stack is only initialised once using `once_cell`
static TRACING: Lazy<()> = Lazy::new(|| {
let level = "info".to_string();
let mut filters = vec![];
filters.push(format!("appflowy_server={}", level));
filters.push(format!("hyper={}", level));
let subscriber_name = "test".to_string();
let subscriber = get_subscriber(subscriber_name, filters.join(","), std::io::stdout);
init_subscriber(subscriber);
});
#[derive(Clone)]
pub struct TestServer {
pub state: State,
pub api_client: reqwest::Client,
pub address: String,
pub port: u16,
}
impl TestServer {
pub async fn register(&self, name: &str, email: &str, password: &str) -> reqwest::Response {
let payload = serde_json::json!({
"name": name,
"password": password,
"email": email
});
let url = format!("{}/api/user/register", self.address);
self.api_client
.post(&url)
.json(&payload)
.send()
.await
.expect("Register failed")
}
pub async fn login(&self, email: &str, password: &str) -> reqwest::Response {
let payload = serde_json::json!({
"password": password,
"email": email
});
let url = format!("{}/api/user/login", self.address);
self.api_client
.post(&url)
.json(&payload)
.send()
.await
.expect("Login failed")
}
pub async fn change_password(
&self,
token: String,
current_password: &str,
new_password: &str,
new_password_confirm: &str,
) -> reqwest::Response {
let payload = serde_json::json!({
"current_password": current_password,
"new_password": new_password,
"new_password_confirm": new_password_confirm
});
let url = format!("{}/api/user/password", self.address);
self.api_client
.post(&url)
.header(HEADER_TOKEN, token)
.json(&payload)
.send()
.await
.expect("Change password failed")
}
}
pub async fn spawn_server() -> TestServer {
Lazy::force(&TRACING);
let database_name = Uuid::new_v4().to_string();
let config = {
let mut config = get_configuration().expect("Failed to read configuration.");
config.database.database_name = database_name.clone();
// Use a random OS port
config.application.port = 0;
config
};
let _ = configure_database(&config.database).await;
let state = init_state(&config).await;
let application = Application::build(config.clone(), state.clone())
.await
.expect("Failed to build application");
let port = application.port();
let _ = tokio::spawn(async {
let _ = application.run_until_stopped().await;
});
let mut builder = reqwest::Client::builder();
let mut address = format!("http://localhost:{}", port);
if config.application.use_https() {
address = format!("https://localhost:{}", port);
builder = builder.add_root_certificate(
Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap(),
);
}
let api_client = builder
.add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap())
.redirect(reqwest::redirect::Policy::none())
.danger_accept_invalid_certs(true)
.cookie_store(true)
.no_proxy()
.build()
.unwrap();
TestServer {
state,
api_client,
address,
port,
}
}
async fn configure_database(config: &DatabaseSetting) -> PgPool {
// Create database
let mut connection = PgConnection::connect_with(&config.without_db())
.await
.expect("Failed to connect to Postgres");
connection
.execute(&*format!(r#"CREATE DATABASE "{}";"#, config.database_name))
.await
.expect("Failed to create database.");
// Migrate database
let connection_pool = PgPool::connect_with(config.with_db())
.await
.expect("Failed to connect to Postgres.");
sqlx::migrate!("./migrations")
.run(&connection_pool)
.await
.expect("Failed to migrate the database");
connection_pool
}
#[derive(serde::Serialize)]
pub struct TestUser {
name: String,
pub email: String,
pub password: String,
}
impl TestUser {
pub fn generate() -> Self {
Self {
name: "Me".to_string(),
email: "me@appflowy.io".to_string(),
password: "Hello@AppFlowy123".to_string(),
}
}
pub async fn register(&self, test_server: &TestServer) -> String {
let url = format!("{}/api/user/register", test_server.address);
let resp = test_server
.api_client
.post(&url)
.json(self)
.send()
.await
.expect("Fail to register user");
let bytes = resp.bytes().await.unwrap();
let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap();
response.token
}
}
pub async fn error_msg_from_resp(resp: reqwest::Response) -> String {
let bytes = resp.bytes().await.unwrap();
String::from_utf8(bytes.to_vec()).unwrap()
}

13
tests/api/ws.rs Normal file
View file

@ -0,0 +1,13 @@
use crate::util::{spawn_server, TestUser};
use collab_client_ws::WSClient;
#[actix_rt::test]
async fn ws_conn_test() {
let server = spawn_server().await;
let test_user = TestUser::generate();
let token = test_user.register(&server).await;
let address = format!("{}/{}", server.ws_addr, token);
let client = WSClient::new(address, 100);
let _ = client.connect().await;
}

3
tests/main.rs Normal file
View file

@ -0,0 +1,3 @@
mod api;
mod util;
mod ws;

2
tests/util/mod.rs Normal file
View file

@ -0,0 +1,2 @@
mod test_server;
pub use test_server::*;

232
tests/util/test_server.rs Normal file
View file

@ -0,0 +1,232 @@
use appflowy_server::application::{init_state, Application};
use appflowy_server::config::config::{get_configuration, DatabaseSetting};
use appflowy_server::state::State;
use appflowy_server::telemetry::{get_subscriber, init_subscriber};
use once_cell::sync::Lazy;
use reqwest::Certificate;
use std::path::PathBuf;
use appflowy_server::component::auth::{RegisterResponse, HEADER_TOKEN};
use sqlx::types::Uuid;
use sqlx::{Connection, Executor, PgConnection, PgPool};
// Ensure that the `tracing` stack is only initialised once using `once_cell`
static TRACING: Lazy<()> = Lazy::new(|| {
let level = "trace".to_string();
let mut filters = vec![];
filters.push(format!("appflowy_server={}", level));
filters.push(format!("collab_client_ws={}", level));
filters.push(format!("hyper={}", level));
filters.push(format!("actix_web={}", level));
let subscriber_name = "test".to_string();
let subscriber = get_subscriber(subscriber_name, filters.join(","), std::io::stdout);
init_subscriber(subscriber);
});
#[derive(Clone)]
pub struct TestServer {
pub state: State,
pub api_client: reqwest::Client,
pub address: String,
pub port: u16,
pub ws_addr: String,
#[allow(dead_code)]
pub cleaner: Cleaner,
}
impl TestServer {
pub async fn register(&self, name: &str, email: &str, password: &str) -> reqwest::Response {
let payload = serde_json::json!({
"name": name,
"password": password,
"email": email
});
let url = format!("{}/api/user/register", self.address);
self
.api_client
.post(&url)
.json(&payload)
.send()
.await
.expect("Register failed")
}
pub async fn login(&self, email: &str, password: &str) -> reqwest::Response {
let payload = serde_json::json!({
"password": password,
"email": email
});
let url = format!("{}/api/user/login", self.address);
self
.api_client
.post(&url)
.json(&payload)
.send()
.await
.expect("Login failed")
}
pub async fn change_password(
&self,
token: String,
current_password: &str,
new_password: &str,
new_password_confirm: &str,
) -> reqwest::Response {
let payload = serde_json::json!({
"current_password": current_password,
"new_password": new_password,
"new_password_confirm": new_password_confirm
});
let url = format!("{}/api/user/password", self.address);
self
.api_client
.post(&url)
.header(HEADER_TOKEN, token)
.json(&payload)
.send()
.await
.expect("Change password failed")
}
}
pub async fn spawn_server() -> TestServer {
Lazy::force(&TRACING);
let database_name = Uuid::new_v4().to_string();
let config = {
let mut config = get_configuration().expect("Failed to read configuration.");
config.database.database_name = database_name.clone();
// Use a random OS port
config.application.port = 0;
config.application.data_dir = PathBuf::from(format!("./data/{}", database_name));
config
};
let _ = configure_database(&config.database).await;
let state = init_state(&config).await;
let application = Application::build(config.clone(), state.clone())
.await
.expect("Failed to build application");
let port = application.port();
let _ = tokio::spawn(async {
let _ = application.run_until_stopped().await;
});
let mut builder = reqwest::Client::builder();
let mut address = format!("http://localhost:{}", port);
let mut ws_addr = format!("ws://localhost:{}/ws", port);
if config.application.use_https() {
address = format!("https://localhost:{}", port);
ws_addr = format!("wss://localhost:{}/ws", port);
builder = builder
.add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap());
}
let api_client = builder
.add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap())
.redirect(reqwest::redirect::Policy::none())
.danger_accept_invalid_certs(true)
.cookie_store(true)
.no_proxy()
.build()
.unwrap();
let cleaner = Cleaner::new(config.application.data_dir);
TestServer {
state,
api_client,
address,
ws_addr,
port,
cleaner,
}
}
async fn configure_database(config: &DatabaseSetting) -> PgPool {
// Create database
let mut connection = PgConnection::connect_with(&config.without_db())
.await
.expect("Failed to connect to Postgres");
connection
.execute(&*format!(r#"CREATE DATABASE "{}";"#, config.database_name))
.await
.expect("Failed to create database.");
// Migrate database
let connection_pool = PgPool::connect_with(config.with_db())
.await
.expect("Failed to connect to Postgres.");
sqlx::migrate!("./migrations")
.run(&connection_pool)
.await
.expect("Failed to migrate the database");
connection_pool
}
#[derive(serde::Serialize)]
pub struct TestUser {
name: String,
pub email: String,
pub password: String,
}
impl TestUser {
pub fn generate() -> Self {
Self {
name: "Me".to_string(),
email: "me@appflowy.io".to_string(),
password: "Hello@AppFlowy123".to_string(),
}
}
pub async fn register(&self, test_server: &TestServer) -> String {
let url = format!("{}/api/user/register", test_server.address);
let resp = test_server
.api_client
.post(&url)
.json(self)
.send()
.await
.expect("Fail to register user");
let bytes = resp.bytes().await.unwrap();
let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap();
response.token
}
}
pub async fn error_msg_from_resp(resp: reqwest::Response) -> String {
let bytes = resp.bytes().await.unwrap();
String::from_utf8(bytes.to_vec()).unwrap()
}
#[derive(Clone)]
pub struct Cleaner {
path: PathBuf,
should_clean: bool,
}
impl Cleaner {
fn new(path: PathBuf) -> Self {
Self {
path,
should_clean: true,
}
}
fn cleanup(dir: &PathBuf) {
let _ = std::fs::remove_dir_all(dir);
}
}
impl Drop for Cleaner {
fn drop(&mut self) {
if self.should_clean {
Self::cleanup(&self.path)
}
}
}

1
tests/ws/mod.rs Normal file
View file

@ -0,0 +1 @@