feat: ws connect (#3)

* chore: ws

* chore: build client stream

* feat: test ws connect

* ci: fix ci
This commit is contained in:
Nathan.fooo 2023-05-08 19:03:50 +08:00 committed by GitHub
parent 08847fad1d
commit 18e950a829
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 2144 additions and 1868 deletions

4
.gitignore vendored
View file

@ -8,4 +8,6 @@
**/temp/** **/temp/**
package-lock.json package-lock.json
yarn.lock yarn.lock
node_modules node_modules
**/crates/AppFlowy-Collab/
data/

449
Cargo.lock generated
View file

@ -89,7 +89,7 @@ dependencies = [
"mime", "mime",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"rand", "rand 0.8.5",
"sha1", "sha1",
"smallvec", "smallvec",
"tokio", "tokio",
@ -189,7 +189,7 @@ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"derive_more", "derive_more",
"rand", "rand 0.8.5",
"redis", "redis",
"serde", "serde",
"serde_json", "serde_json",
@ -370,7 +370,7 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [ dependencies = [
"getrandom", "getrandom 0.2.9",
"once_cell", "once_cell",
"version_check", "version_check",
] ]
@ -382,7 +382,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"getrandom", "getrandom 0.2.9",
"once_cell", "once_cell",
"version_check", "version_check",
] ]
@ -446,6 +446,9 @@ dependencies = [
"bincode", "bincode",
"bytes", "bytes",
"chrono", "chrono",
"collab-client-ws",
"collab-persistence",
"collab-sync",
"config", "config",
"dashmap", "dashmap",
"derive_more", "derive_more",
@ -454,7 +457,7 @@ dependencies = [
"lazy_static", "lazy_static",
"once_cell", "once_cell",
"openssl", "openssl",
"rand", "rand 0.8.5",
"rcgen", "rcgen",
"reqwest", "reqwest",
"secrecy", "secrecy",
@ -474,6 +477,7 @@ dependencies = [
"unicode-segmentation", "unicode-segmentation",
"uuid", "uuid",
"validator", "validator",
"websocket",
] ]
[[package]] [[package]]
@ -574,6 +578,12 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "atomic_refcell"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79d6dc922a2792b006573f60b2648076355daeae5ce9cb59507e5908c9625d31"
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.1.0" version = "1.1.0"
@ -613,6 +623,26 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "bindgen"
version = "0.64.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4243e6031260db77ede97ad86c27e501d646a27ab57b59a574f725d98ab1fb4"
dependencies = [
"bitflags",
"cexpr",
"clang-sys",
"lazy_static",
"lazycell",
"peeking_take_while",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn 1.0.109",
]
[[package]] [[package]]
name = "bit-set" name = "bit-set"
version = "0.5.3" version = "0.5.3"
@ -700,6 +730,17 @@ dependencies = [
"bytes", "bytes",
] ]
[[package]]
name = "bzip2-sys"
version = "0.1.11+1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc"
dependencies = [
"cc",
"libc",
"pkg-config",
]
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.79" version = "1.0.79"
@ -709,6 +750,15 @@ dependencies = [
"jobserver", "jobserver",
] ]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
version = "1.0.0" version = "1.0.0"
@ -738,6 +788,17 @@ dependencies = [
"inout", "inout",
] ]
[[package]]
name = "clang-sys"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c688fc74432808e3eb684cae8830a86be1d66a2bd58e1f248ed0960a590baf6f"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]] [[package]]
name = "codespan-reporting" name = "codespan-reporting"
version = "0.11.1" version = "0.11.1"
@ -748,6 +809,91 @@ dependencies = [
"unicode-width", "unicode-width",
] ]
[[package]]
name = "collab"
version = "0.1.0"
dependencies = [
"anyhow",
"bytes",
"lib0",
"parking_lot 0.12.1",
"serde",
"serde_json",
"thiserror",
"tracing",
"y-sync",
"yrs",
]
[[package]]
name = "collab-client-ws"
version = "0.1.0"
dependencies = [
"bytes",
"futures-util",
"serde",
"serde_json",
"thiserror",
"tokio",
"tokio-retry",
"tokio-stream",
"tokio-tungstenite",
"tracing",
]
[[package]]
name = "collab-persistence"
version = "0.1.0"
dependencies = [
"bincode",
"chrono",
"lazy_static",
"lib0",
"parking_lot 0.12.1",
"rocksdb",
"serde",
"sled",
"smallvec",
"thiserror",
"tokio",
"tracing",
"yrs",
]
[[package]]
name = "collab-plugins"
version = "0.1.0"
dependencies = [
"collab",
"collab-client-ws",
"collab-persistence",
"collab-sync",
"tracing",
"y-sync",
"yrs",
]
[[package]]
name = "collab-sync"
version = "0.1.0"
dependencies = [
"bytes",
"collab",
"futures-util",
"lib0",
"md5",
"parking_lot 0.12.1",
"serde",
"serde_json",
"thiserror",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"y-sync",
"yrs",
]
[[package]] [[package]]
name = "combine" name = "combine"
version = "4.6.6" version = "4.6.6"
@ -793,7 +939,7 @@ dependencies = [
"hkdf", "hkdf",
"hmac", "hmac",
"percent-encoding", "percent-encoding",
"rand", "rand 0.8.5",
"sha2", "sha2",
"subtle", "subtle",
"time", "time",
@ -914,7 +1060,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [ dependencies = [
"generic-array", "generic-array",
"rand_core", "rand_core 0.6.4",
"typenum", "typenum",
] ]
@ -1308,6 +1454,19 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "getrandom"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi 0.9.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]] [[package]]
name = "getrandom" name = "getrandom"
version = "0.2.9" version = "0.2.9"
@ -1316,7 +1475,7 @@ checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
"wasi", "wasi 0.11.0+wasi-snapshot-preview1",
] ]
[[package]] [[package]]
@ -1329,6 +1488,12 @@ dependencies = [
"polyval", "polyval",
] ]
[[package]]
name = "glob"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.3.18" version = "0.3.18"
@ -1635,12 +1800,65 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "lib0"
version = "0.16.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf23122cb1c970b77ea6030eac5e328669415b65d2ab245c99bfb110f9d62dc"
dependencies = [
"serde",
"serde_json",
"thiserror",
]
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.142" version = "0.2.142"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317" checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317"
[[package]]
name = "libloading"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f"
dependencies = [
"cfg-if",
"winapi",
]
[[package]]
name = "librocksdb-sys"
version = "0.10.0+7.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fe4d5874f5ff2bc616e55e8c6086d478fcda13faf9495768a4aa1c22042d30b"
dependencies = [
"bindgen",
"bzip2-sys",
"cc",
"glob",
"libc",
"libz-sys",
"zstd-sys",
]
[[package]]
name = "libz-sys"
version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56ee889ecc9568871456d42f603d6a0ce59ff328d291063a45cbdf0036baf6db"
dependencies = [
"cc",
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "link-cplusplus" name = "link-cplusplus"
version = "1.0.8" version = "1.0.8"
@ -1723,6 +1941,12 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "md5"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.5.0" version = "2.5.0"
@ -1767,7 +1991,7 @@ checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9"
dependencies = [ dependencies = [
"libc", "libc",
"log", "log",
"wasi", "wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.45.0", "windows-sys 0.45.0",
] ]
@ -1975,7 +2199,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
dependencies = [ dependencies = [
"base64ct", "base64ct",
"rand_core", "rand_core 0.6.4",
"subtle", "subtle",
] ]
@ -1991,6 +2215,12 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd" checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd"
[[package]]
name = "peeking_take_while"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]] [[package]]
name = "pem" name = "pem"
version = "1.1.1" version = "1.1.1"
@ -2096,6 +2326,19 @@ dependencies = [
"proc-macro2", "proc-macro2",
] ]
[[package]]
name = "rand"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
dependencies = [
"getrandom 0.1.16",
"libc",
"rand_chacha 0.2.2",
"rand_core 0.5.1",
"rand_hc",
]
[[package]] [[package]]
name = "rand" name = "rand"
version = "0.8.5" version = "0.8.5"
@ -2103,8 +2346,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [ dependencies = [
"libc", "libc",
"rand_chacha", "rand_chacha 0.3.1",
"rand_core", "rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
dependencies = [
"ppv-lite86",
"rand_core 0.5.1",
] ]
[[package]] [[package]]
@ -2114,7 +2367,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [ dependencies = [
"ppv-lite86", "ppv-lite86",
"rand_core", "rand_core 0.6.4",
]
[[package]]
name = "rand_core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
dependencies = [
"getrandom 0.1.16",
] ]
[[package]] [[package]]
@ -2123,7 +2385,16 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [ dependencies = [
"getrandom", "getrandom 0.2.9",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
dependencies = [
"rand_core 0.5.1",
] ]
[[package]] [[package]]
@ -2186,7 +2457,7 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
dependencies = [ dependencies = [
"getrandom", "getrandom 0.2.9",
"redox_syscall 0.2.16", "redox_syscall 0.2.16",
"thiserror", "thiserror",
] ]
@ -2264,17 +2535,6 @@ dependencies = [
"winreg", "winreg",
] ]
[[package]]
name = "revdb"
version = "0.1.0"
dependencies = [
"bincode",
"serde",
"sled",
"tempfile",
"thiserror",
]
[[package]] [[package]]
name = "ring" name = "ring"
version = "0.16.20" version = "0.16.20"
@ -2290,6 +2550,22 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "rocksdb"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "015439787fce1e75d55f279078d33ff14b4af5d93d995e8838ee4631301c8a99"
dependencies = [
"libc",
"librocksdb-sys",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]] [[package]]
name = "rustc_version" name = "rustc_version"
version = "0.4.0" version = "0.4.0"
@ -2504,6 +2780,12 @@ dependencies = [
"lazy_static", "lazy_static",
] ]
[[package]]
name = "shlex"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3"
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.1" version = "1.4.1"
@ -2538,6 +2820,15 @@ dependencies = [
"parking_lot 0.11.2", "parking_lot 0.11.2",
] ]
[[package]]
name = "smallstr"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e922794d168678729ffc7e07182721a14219c65814e66e91b839a272fe5ae4f"
dependencies = [
"smallvec",
]
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.10.0" version = "1.10.0"
@ -2621,7 +2912,7 @@ dependencies = [
"once_cell", "once_cell",
"paste", "paste",
"percent-encoding", "percent-encoding",
"rand", "rand 0.8.5",
"rustls", "rustls",
"rustls-pemfile", "rustls-pemfile",
"serde", "serde",
@ -2881,6 +3172,17 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-retry"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f"
dependencies = [
"pin-project",
"rand 0.8.5",
"tokio",
]
[[package]] [[package]]
name = "tokio-rustls" name = "tokio-rustls"
version = "0.23.4" version = "0.23.4"
@ -2901,6 +3203,19 @@ dependencies = [
"futures-core", "futures-core",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
"tokio-util",
]
[[package]]
name = "tokio-tungstenite"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
] ]
[[package]] [[package]]
@ -3035,6 +3350,25 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
name = "tungstenite"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788"
dependencies = [
"base64 0.13.1",
"byteorder",
"bytes",
"http",
"httparse",
"log",
"rand 0.8.5",
"sha1",
"thiserror",
"url",
"utf-8",
]
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.16.0" version = "1.16.0"
@ -3113,13 +3447,19 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.3.2" version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dad5567ad0cf5b760e5665964bec1b47dfd077ba8a2544b513f3556d3d239a2" checksum = "4dad5567ad0cf5b760e5665964bec1b47dfd077ba8a2544b513f3556d3d239a2"
dependencies = [ dependencies = [
"getrandom", "getrandom 0.2.9",
"serde", "serde",
] ]
@ -3166,6 +3506,12 @@ dependencies = [
"try-lock", "try-lock",
] ]
[[package]]
name = "wasi"
version = "0.9.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.11.0+wasi-snapshot-preview1" version = "0.11.0+wasi-snapshot-preview1"
@ -3267,6 +3613,28 @@ dependencies = [
"webpki", "webpki",
] ]
[[package]]
name = "websocket"
version = "0.1.0"
dependencies = [
"actix",
"actix-web-actors",
"bytes",
"collab",
"collab-persistence",
"collab-plugins",
"collab-sync",
"dashmap",
"futures-util",
"parking_lot 0.12.1",
"secrecy",
"serde",
"thiserror",
"tokio",
"tokio-stream",
"tracing",
]
[[package]] [[package]]
name = "whoami" name = "whoami"
version = "1.4.0" version = "1.4.0"
@ -3492,6 +3860,17 @@ dependencies = [
"time", "time",
] ]
[[package]]
name = "y-sync"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f54d34b68ec4514a0659838c2b1ba867c571b20b3804a1338dacf4fa9062d801"
dependencies = [
"lib0",
"thiserror",
"yrs",
]
[[package]] [[package]]
name = "yaml-rust" name = "yaml-rust"
version = "0.4.5" version = "0.4.5"
@ -3510,6 +3889,20 @@ dependencies = [
"time", "time",
] ]
[[package]]
name = "yrs"
version = "0.16.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c2aef2bf89b4f7c003f9c73f1c8097427ca32e1d006443f3f607f11e79a797b"
dependencies = [
"atomic_refcell",
"lib0",
"rand 0.7.3",
"smallstr",
"smallvec",
"thiserror",
]
[[package]] [[package]]
name = "zeroize" name = "zeroize"
version = "1.6.0" version = "1.6.0"

View file

@ -51,6 +51,8 @@ bytes = "1.4.0"
bincode = "1.3.3" bincode = "1.3.3"
dashmap = "5.4" dashmap = "5.4"
rcgen = { version = "0.10.0", features = ["pem", "x509-parser"] } rcgen = { version = "0.10.0", features = ["pem", "x509-parser"] }
collab-sync = {version = "0.1.0"}
collab-persistence = {version = "0.1.0"}
# tracing # tracing
tracing = { version = "0.1.37" } tracing = { version = "0.1.37" }
@ -63,9 +65,11 @@ sqlx = { version = "0.6", default-features = false, features = ["runtime-actix-r
#Local crate #Local crate
token = { path = "./crates/token" } token = { path = "./crates/token" }
snowflake = { path = "./crates/snowflake" } snowflake = { path = "./crates/snowflake" }
websocket = { path = "./crates/websocket" }
[dev-dependencies] [dev-dependencies]
once_cell = "1.7.2" once_cell = "1.7.2"
collab-client-ws = { version = "0.1.0" }
[[bin]] [[bin]]
name = "appflowy_server" name = "appflowy_server"
@ -78,6 +82,19 @@ path = "src/lib.rs"
[workspace] [workspace]
members = [ members = [
"crates/token", "crates/token",
"crates/revdb",
"crates/snowflake", "crates/snowflake",
"crates/websocket",
] ]
[patch.crates-io]
collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
collab-client-ws = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
collab-sync= { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
collab-persistence = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
collab-plugins = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
#collab = { path = "./crates/AppFlowy-Collab/collab" }
#collab-client-ws = { path = "./crates/AppFlowy-Collab/collab-client-ws" }
#collab-sync = { path = "./crates/AppFlowy-Collab/collab-sync" }
#collab-persistence = { path = "./crates/AppFlowy-Collab/collab-persistence" }
#collab-plugins = { path = "./crates/AppFlowy-Collab/collab-plugins"}

View file

@ -3,6 +3,7 @@ application:
host: 0.0.0.0 host: 0.0.0.0
server_key: "Should-Use-The-Custom-Server-Key" server_key: "Should-Use-The-Custom-Server-Key"
tls_config: "no_tls" tls_config: "no_tls"
data_dir: "./data"
database: database:
host: "localhost" host: "localhost"
port: 5432 port: 5432

View file

@ -1,15 +0,0 @@
[package]
name = "revdb"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
sled = "0.34.7"
thiserror = "1.0.30"
serde = { version = "1.0", features = ["derive"] }
bincode = "1.3.3"
[dev-dependencies]
tempfile = "3.4.0"

View file

@ -1,56 +0,0 @@
use crate::document::Document;
use crate::error::RevDBError;
use sled::{Batch, Db, IVec};
use std::path::Path;
pub struct RevDB {
pub(crate) db: Db,
}
impl RevDB {
pub fn open(path: impl AsRef<Path>) -> Result<Self, RevDBError> {
let db = sled::open(path)?;
Ok(Self { db })
}
pub fn document(&self) -> Document {
Document { db: self }
}
pub fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<IVec>, RevDBError> {
let value = self.db.get(key)?;
Ok(value)
}
pub fn batch_get<K: AsRef<[u8]>>(
&self,
from_key: K,
to_key: K,
) -> Result<Vec<IVec>, RevDBError> {
let iter = self.db.range(from_key..=to_key);
let mut items = vec![];
for item in iter {
let (_, value) = item?;
items.push(value)
}
Ok(items)
}
pub fn insert<K: AsRef<[u8]>>(&self, key: K, value: &[u8]) -> Result<(), RevDBError> {
let _ = self.db.insert(key, value)?;
Ok(())
}
pub fn batch_insert<'a, K: AsRef<[u8]>>(
&self,
items: impl IntoIterator<Item = (K, &'a [u8])>,
) -> Result<(), RevDBError> {
let mut batch = Batch::default();
let items = items.into_iter();
items.for_each(|(key, value)| {
batch.insert(key.as_ref(), value);
});
self.db.apply_batch(batch)?;
Ok(())
}
}

View file

@ -1,117 +0,0 @@
use crate::db::RevDB;
use crate::error::RevDBError;
use crate::range::RevRange;
use serde::{Deserialize, Serialize};
pub struct Document<'a> {
pub(crate) db: &'a RevDB,
}
impl<'a> Document<'a> {
pub fn insert(
&self,
uid: i64,
document_id: i64,
value: DocumentRevData,
) -> Result<(), RevDBError> {
let key = make_document_key(uid, document_id, value.rev_id);
self.db.insert(key, &value.to_vec()?)?;
Ok(())
}
pub fn get(
&self,
uid: i64,
document_id: i64,
rev_id: i64,
) -> Result<Option<DocumentRevData>, RevDBError> {
let key = make_document_key(uid, document_id, rev_id);
match self.db.get(key)? {
None => Ok(None),
Some(value) => {
let data = DocumentRevData::from_vec(value.as_ref())?;
Ok(Some(data))
}
}
}
pub fn get_with_range<R: Into<RevRange>>(
&self,
uid: i64,
document_id: i64,
range: R,
) -> Result<Vec<DocumentRevData>, RevDBError> {
let range = range.into();
let from = make_document_key(uid, document_id, range.start);
let to = make_document_key(uid, document_id, range.end);
self.batch_get(from, to)
}
pub fn get_after(
&self,
uid: i64,
document_id: i64,
rev_id: i64,
) -> Result<Vec<DocumentRevData>, RevDBError> {
let from = make_document_key(uid, document_id, rev_id);
let to = make_document_key(uid, document_id, i64::MAX);
self.batch_get(from, to)
}
pub fn get_before(
&self,
uid: i64,
document_id: i64,
rev_id: i64,
) -> Result<Vec<DocumentRevData>, RevDBError> {
let from = make_document_key(uid, document_id, 0);
let to = make_document_key(uid, document_id, rev_id);
self.batch_get(from, to)
}
fn batch_get<K: AsRef<[u8]>>(
&self,
from: K,
to: K,
) -> Result<Vec<DocumentRevData>, RevDBError> {
let items = self.db.batch_get(from, to)?;
let mut document_revs = vec![];
for item in items {
let rev_data = DocumentRevData::from_vec(item.as_ref())?;
document_revs.push(rev_data);
}
Ok(document_revs)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentRevData {
#[serde(rename = "rid")]
pub rev_id: i64,
#[serde(rename = "bid")]
pub base_rev_id: i64,
#[serde(rename = "data")]
pub content: String,
}
impl DocumentRevData {
pub fn from_vec(data: &[u8]) -> Result<Self, RevDBError> {
bincode::deserialize::<Self>(data).map_err(|_e| RevDBError::SerdeError)
}
pub fn to_vec(&self) -> Result<Vec<u8>, RevDBError> {
bincode::serialize(self).map_err(|_e| RevDBError::SerdeError)
}
}
// Optimize your data layout: Sled's B-Tree implementation works best when the keys are sequential,
// so try to organize the data in a way that maximizes sequential access.
fn make_document_key(uid: i64, document_id: i64, rev_id: i64) -> [u8; 24] {
let mut key = [0; 24];
key[0..8].copy_from_slice(&uid.to_be_bytes());
key[8..16].copy_from_slice(&document_id.to_be_bytes());
key[16..24].copy_from_slice(&rev_id.to_be_bytes());
key
}

View file

@ -1,11 +0,0 @@
#[derive(Debug, thiserror::Error)]
pub enum RevDBError {
#[error(transparent)]
Db(#[from] sled::Error),
#[error("Serde error")]
SerdeError,
#[error("invalid data")]
InvalidData,
}

View file

@ -1,4 +0,0 @@
pub mod db;
pub mod document;
pub mod error;
pub mod range;

View file

@ -1,48 +0,0 @@
use std::ops::{Range, RangeInclusive, RangeToInclusive};
#[derive(Clone)]
pub struct RevRange {
pub(crate) start: i64,
pub(crate) end: i64,
}
impl RevRange {
/// Construct a new `RevRange` representing the range [start..end).
/// It is an invariant that `start <= end`.
pub fn new(start: i64, end: i64) -> RevRange {
debug_assert!(start <= end);
RevRange { start, end }
}
}
impl From<RangeInclusive<i64>> for RevRange {
fn from(src: RangeInclusive<i64>) -> RevRange {
RevRange::new(*src.start(), src.end().saturating_add(1))
}
}
impl From<RangeToInclusive<i64>> for RevRange {
fn from(src: RangeToInclusive<i64>) -> RevRange {
RevRange::new(0, src.end.saturating_add(1))
}
}
impl From<Range<i64>> for RevRange {
fn from(src: Range<i64>) -> RevRange {
let Range { start, end } = src;
RevRange { start, end }
}
}
impl Iterator for RevRange {
type Item = i64;
fn next(&mut self) -> Option<i64> {
if self.start > self.end {
return None;
}
let val = self.start;
self.start += 1;
Some(val)
}
}

View file

@ -1,2 +0,0 @@
mod test;
mod util;

View file

@ -1,136 +0,0 @@
use crate::document::util::make_test_db;
use revdb::document::{Document, DocumentRevData};
use revdb::range::RevRange;
#[test]
fn insert_text() {
let db = make_test_db();
let document = db.document();
let uid = 12345678;
let document_id = 1;
let value = DocumentRevData {
rev_id: 0,
base_rev_id: 0,
content: "hello world".to_string(),
};
document.insert(uid, document_id, value.clone()).unwrap();
let restored_data = document.get(uid, document_id, 0).unwrap().unwrap();
assert_eq!(value.content, restored_data.content);
}
//noinspection RsExternalLinter
#[test]
fn insert_multi_text() {
let db = make_test_db();
let document = db.document();
let uid = 12345678;
let document_id = 1;
let mut base_rev_id = 0;
let mut expected_str = "".to_string();
for i in 0..=100 {
let content = i.to_string();
expected_str.push_str(&content);
let value = DocumentRevData {
rev_id: i,
base_rev_id,
content,
};
base_rev_id += 1;
document.insert(uid, document_id, value).unwrap();
}
//
let restored_str = document
.get_with_range(uid, document_id, RevRange::new(0, 100))
.unwrap()
.into_iter()
.map(|data| data.content)
.collect::<Vec<String>>()
.join("");
assert_eq!(expected_str, restored_str);
}
//noinspection RsExternalLinter
fn insert_100_string_to_document(uid: i64, document_id: i64, document: &Document) {
let mut base_rev_id = 0;
for i in 0..=100 {
let content = i.to_string();
let value = DocumentRevData {
rev_id: i,
base_rev_id,
content,
};
base_rev_id += 1;
document.insert(uid, document_id, value).unwrap();
}
}
fn values_to_string(values: Vec<DocumentRevData>) -> String {
values
.into_iter()
.map(|data| data.content)
.collect::<Vec<String>>()
.join("")
}
#[test]
fn get_value_before() {
let db = make_test_db();
let document = db.document();
let uid = 12345678;
let document_id = 1;
insert_100_string_to_document(uid, document_id, &document);
let restored_str = values_to_string(document.get_before(uid, document_id, 50).unwrap());
assert_eq!("01234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950", restored_str);
let restored_str = values_to_string(document.get_before(uid, document_id, 0).unwrap());
assert_eq!("0", restored_str);
}
#[test]
fn get_value_after() {
let db = make_test_db();
let document = db.document();
let uid = 12345678;
let document_id = 1;
insert_100_string_to_document(uid, document_id, &document);
let restored_str = values_to_string(document.get_after(uid, document_id, 50).unwrap());
assert_eq!("5051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100", restored_str);
let restored_str = values_to_string(document.get_after(uid, document_id, 100).unwrap());
assert_eq!("100", restored_str);
}
#[test]
fn get_value_with_range() {
let db = make_test_db();
let document = db.document();
let uid = 12345678;
let document_id = 1;
insert_100_string_to_document(uid, document_id, &document);
let restored_str = values_to_string(
document
.get_with_range(uid, document_id, RevRange::new(50, 60))
.unwrap(),
);
assert_eq!("5051525354555657585960", restored_str);
let restored_str = values_to_string(
document
.get_with_range(uid, document_id, RevRange::new(50, 200))
.unwrap(),
);
assert_eq!("5051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100", restored_str);
let restored_str = values_to_string(
document
.get_with_range(uid, document_id, RevRange::new(50, 50))
.unwrap(),
);
assert_eq!("50", restored_str);
}

View file

@ -1,7 +0,0 @@
use revdb::db::RevDB;
use tempfile::TempDir;
pub fn make_test_db() -> RevDB {
let tempdir = TempDir::new().unwrap();
RevDB::open(tempdir).unwrap()
}

View file

@ -1 +0,0 @@
mod document;

View file

@ -8,66 +8,65 @@ const TIMESTAMP_SHIFT: u64 = NODE_ID_BITS + SEQUENCE_BITS;
const SEQUENCE_MASK: u64 = (1 << SEQUENCE_BITS) - 1; const SEQUENCE_MASK: u64 = (1 << SEQUENCE_BITS) - 1;
pub struct Snowflake { pub struct Snowflake {
node_id: u64, node_id: u64,
sequence: u64, sequence: u64,
last_timestamp: u64, last_timestamp: u64,
} }
impl Snowflake { impl Snowflake {
pub fn new(node_id: u64) -> Snowflake { pub fn new(node_id: u64) -> Snowflake {
Snowflake { Snowflake {
node_id, node_id,
sequence: 0, sequence: 0,
last_timestamp: 0, last_timestamp: 0,
} }
}
pub fn next_id(&mut self) -> i64 {
let timestamp = self.timestamp();
if timestamp < self.last_timestamp {
panic!("Clock moved backwards!");
} }
pub fn next_id(&mut self) -> i64 { if timestamp == self.last_timestamp {
let timestamp = self.timestamp(); self.sequence = (self.sequence + 1) & SEQUENCE_MASK;
if timestamp < self.last_timestamp { if self.sequence == 0 {
panic!("Clock moved backwards!"); self.wait_next_millis();
} }
} else {
if timestamp == self.last_timestamp { self.sequence = 0;
self.sequence = (self.sequence + 1) & SEQUENCE_MASK;
if self.sequence == 0 {
self.wait_next_millis();
}
} else {
self.sequence = 0;
}
self.last_timestamp = timestamp;
let id =
(timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence;
id as i64
} }
fn wait_next_millis(&self) { self.last_timestamp = timestamp;
let mut timestamp = self.timestamp(); let id = (timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence;
while timestamp == self.last_timestamp { id as i64
timestamp = self.timestamp(); }
}
}
fn timestamp(&self) -> u64 { fn wait_next_millis(&self) {
SystemTime::now() let mut timestamp = self.timestamp();
.duration_since(SystemTime::UNIX_EPOCH) while timestamp == self.last_timestamp {
.expect("Clock moved backwards!") timestamp = self.timestamp();
.as_millis() as u64
} }
}
fn timestamp(&self) -> u64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("Clock moved backwards!")
.as_millis() as u64
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::Snowflake; use crate::Snowflake;
#[test] #[test]
fn gen_id() { fn gen_id() {
let mut snow_flake = Snowflake::new(1); let mut snow_flake = Snowflake::new(1);
let id_1 = snow_flake.next_id(); let id_1 = snow_flake.next_id();
let id_2 = snow_flake.next_id(); let id_2 = snow_flake.next_id();
assert_ne!(id_1, id_2); assert_ne!(id_1, id_2);
} }
} }

View file

@ -7,70 +7,72 @@ use sha2::Sha256;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum TokenError { pub enum TokenError {
#[error(transparent)] #[error(transparent)]
Jwt(#[from] jwt::Error), Jwt(#[from] jwt::Error),
#[error("Token expired")] #[error("Token expired")]
Expired, Expired,
} }
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub enum TokenType { pub enum TokenType {
AccessToken, AccessToken,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct TokenFields<T> { struct TokenFields<T> {
#[serde(rename = "d")] #[serde(rename = "d")]
data: T, data: T,
#[serde(rename = "exp")] #[serde(rename = "exp")]
expire_at: DateTime<Utc>, expire_at: DateTime<Utc>,
} }
pub fn create_token( pub fn create_token(
server_key: &str, server_key: &str,
data: impl Serialize, data: impl Serialize,
expire_duration: Duration, expire_duration: Duration,
) -> Result<String, TokenError> { ) -> Result<String, TokenError> {
Ok(TokenFields { Ok(
data, TokenFields {
expire_at: Utc::now() + expire_duration, data,
expire_at: Utc::now() + expire_duration,
} }
.sign_with_key(&generate_hmac_key(server_key))?) .sign_with_key(&generate_hmac_key(server_key))?,
)
} }
fn generate_hmac_key(server_key: &str) -> Hmac<Sha256> { fn generate_hmac_key(server_key: &str) -> Hmac<Sha256> {
Hmac::<Sha256>::new_from_slice(server_key.as_bytes()).expect("invalid server key") Hmac::<Sha256>::new_from_slice(server_key.as_bytes()).expect("invalid server key")
} }
pub fn parse_token<T: DeserializeOwned>(server_key: &str, token: &str) -> Result<T, TokenError> { pub fn parse_token<T: DeserializeOwned>(server_key: &str, token: &str) -> Result<T, TokenError> {
let fields = let fields =
VerifyWithKey::<TokenFields<T>>::verify_with_key(token, &generate_hmac_key(server_key))?; VerifyWithKey::<TokenFields<T>>::verify_with_key(token, &generate_hmac_key(server_key))?;
if fields.expire_at < Utc::now() { if fields.expire_at < Utc::now() {
return Err(TokenError::Expired); return Err(TokenError::Expired);
} }
Ok(fields.data) Ok(fields.data)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn create_token_test() { fn create_token_test() {
let token_data = "hello appflowy".to_string(); let token_data = "hello appflowy".to_string();
let token = create_token("server_key", &token_data, Duration::days(2)).unwrap(); let token = create_token("server_key", &token_data, Duration::days(2)).unwrap();
let parse_token_data = parse_token::<String>("server_key", &token).unwrap(); let parse_token_data = parse_token::<String>("server_key", &token).unwrap();
assert_eq!(token_data, parse_token_data); assert_eq!(token_data, parse_token_data);
} }
#[test] #[test]
#[should_panic] #[should_panic]
fn parser_token_with_different_server_key() { fn parser_token_with_different_server_key() {
let server_key = "123456"; let server_key = "123456";
let token = create_token(server_key, "hello", Duration::days(2)).unwrap(); let token = create_token(server_key, "hello", Duration::days(2)).unwrap();
let _ = parse_token::<String>("abcdef", &token).unwrap(); let _ = parse_token::<String>("abcdef", &token).unwrap();
} }
} }

View file

@ -0,0 +1,25 @@
[package]
name = "websocket"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
actix = "0.13"
actix-web-actors = { version = "4.2.0" }
serde = { version = "1.0", features = ["derive"] }
thiserror = "1.0.30"
bytes = "1.0"
secrecy = { version = "0.8", features = ["serde"] }
parking_lot = "0.12.1"
tracing = "0.1.25"
futures-util = "0.3.26"
tokio-stream = { version = "0.1.14", features = ["sync"] }
tokio = { version = "1.26", features = ["sync"] }
dashmap = "5.4.0"
collab = { version = "0.1.0"}
collab-sync = { version = "0.1.0"}
collab-persistence = { version = "0.1.0"}
collab-plugins = { version = "0.1.0", features = ["disk_rocksdb"]}

View file

@ -0,0 +1,168 @@
use crate::entities::{ClientMessage, Connect, Disconnect, ServerMessage, WSUser};
use crate::error::WSError;
use crate::CollabServer;
use actix::{
fut, Actor, ActorContext, ActorFutureExt, Addr, AsyncContext, ContextFutureSpawner, Handler,
Recipient, Running, StreamHandler, WrapFuture,
};
use actix_web_actors::ws;
use bytes::Bytes;
use collab_sync::msg::CollabMessage;
use futures_util::Sink;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
pub struct CollabSession {
user: Arc<WSUser>,
hb: Instant,
pub server: Addr<CollabServer>,
}
impl CollabSession {
pub fn new(user: WSUser, server: Addr<CollabServer>) -> Self {
Self {
user: Arc::new(user),
hb: Instant::now(),
server,
}
}
fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
act.server.do_send(Disconnect {
user: act.user.clone(),
});
ctx.stop();
return;
}
ctx.ping(b"");
});
}
fn send_to_server(&self, bytes: Bytes) {
match CollabMessage::from_vec(bytes.to_vec()) {
Ok(collab_msg) => {
self.server.do_send(ClientMessage {
user: self.user.clone(),
collab_msg,
});
},
Err(e) => {
tracing::error!("Error parsing message: {:?}", e);
},
}
}
}
impl Actor for CollabSession {
type Context = ws::WebsocketContext<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
// start heartbeats otherwise server disconnects in 10 seconds
self.hb(ctx);
self
.server
.send(Connect {
socket: ctx.address().recipient(),
user: self.user.clone(),
})
.into_actor(self)
.then(|res, _session, ctx| {
match res {
Ok(Ok(_)) => {
tracing::trace!("Send connect message to server success")
},
_ => {
tracing::error!("Send connect message to server failed");
ctx.stop();
},
}
fut::ready(())
})
.wait(ctx);
}
fn stopping(&mut self, _: &mut Self::Context) -> Running {
self.server.do_send(Disconnect {
user: self.user.clone(),
});
Running::Stop
}
}
impl Handler<ServerMessage> for CollabSession {
type Result = ();
fn handle(&mut self, msg: ServerMessage, ctx: &mut Self::Context) {
ctx.binary(msg.collab_msg);
}
}
/// WebSocket message handler
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for CollabSession {
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
let msg = match msg {
Err(_) => {
ctx.stop();
return;
},
Ok(msg) => msg,
};
match msg {
ws::Message::Ping(msg) => {
self.hb = Instant::now();
ctx.pong(&msg);
},
ws::Message::Pong(_) => {
self.hb = Instant::now();
},
ws::Message::Text(_) => {},
ws::Message::Binary(bytes) => {
self.send_to_server(bytes);
},
ws::Message::Close(reason) => {
ctx.close(reason);
ctx.stop();
},
ws::Message::Continuation(_) => {
ctx.stop();
},
ws::Message::Nop => (),
}
}
}
/// A helper struct that wraps the [Recipient] type to implement the [Sink] trait
pub struct ClientSink(pub Recipient<ServerMessage>);
impl Sink<CollabMessage> for ClientSink {
type Error = WSError;
fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn start_send(self: Pin<&mut Self>, item: CollabMessage) -> Result<(), Self::Error> {
self.0.do_send(ServerMessage { collab_msg: item });
Ok(())
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
}

View file

@ -0,0 +1,56 @@
use crate::error::WSError;
use actix::{Message, Recipient};
use collab_sync::msg::CollabMessage;
use secrecy::{ExposeSecret, Secret};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct WSUser {
pub user_id: Secret<String>,
}
impl Hash for WSUser {
fn hash<H: Hasher>(&self, state: &mut H) {
let uid: &String = self.user_id.expose_secret();
uid.hash(state);
}
}
impl PartialEq<Self> for WSUser {
fn eq(&self, other: &Self) -> bool {
let uid: &String = self.user_id.expose_secret();
let other_uid: &String = other.user_id.expose_secret();
uid == other_uid
}
}
impl Eq for WSUser {}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), WSError>")]
pub struct Connect {
pub socket: Recipient<ServerMessage>,
pub user: Arc<WSUser>,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), WSError>")]
pub struct Disconnect {
pub user: Arc<WSUser>,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "()")]
pub struct ClientMessage {
pub user: Arc<WSUser>,
pub collab_msg: CollabMessage,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "()")]
pub struct ServerMessage {
pub collab_msg: CollabMessage,
}

View file

@ -0,0 +1,8 @@
#[derive(Debug, thiserror::Error)]
pub enum WSError {
#[error(transparent)]
Persistence(#[from] collab_persistence::error::PersistenceError),
#[error("Internal failure: {0}")]
Internal(#[from] Box<dyn std::error::Error + Send + Sync>),
}

View file

@ -0,0 +1,7 @@
mod client;
pub mod entities;
mod error;
mod server;
pub use client::*;
pub use server::*;

View file

@ -0,0 +1,190 @@
use crate::entities::{ClientMessage, Connect, Disconnect, WSUser};
use crate::error::WSError;
use crate::ClientSink;
use actix::{Actor, Context, Handler, ResponseFuture};
use collab::core::collab::MutexCollab;
use collab::core::origin::CollabOrigin;
use collab_persistence::kv::rocks_kv::RocksCollabDB;
use collab_persistence::kv::KVStore;
use collab_plugins::disk_plugin::rocksdb_server::RocksdbServerDiskPlugin;
use collab_sync::server::{
CollabBroadcast, CollabGroup, CollabIDGen, CollabId, NonZeroNodeId, COLLAB_ID_LEN,
};
use dashmap::DashMap;
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use collab_persistence::keys::make_collab_id_key;
use collab_sync::msg::CollabMessage;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use tokio_stream::wrappers::ReceiverStream;
#[derive(Clone)]
pub struct CollabServer {
db: Arc<RocksCollabDB>,
/// Generate collab_id for new collab object
collab_id_gen: Arc<Mutex<CollabIDGen>>,
/// Memory cache for fast lookup of collab_id from object_id
collab_id_by_object_id: Arc<DashMap<String, CollabId>>,
collab_groups: Arc<RwLock<HashMap<CollabId, CollabGroup>>>,
client_streams: Arc<RwLock<HashMap<Arc<WSUser>, ClientStream>>>,
}
impl CollabServer {
pub fn new(db: Arc<RocksCollabDB>) -> Result<Self, WSError> {
let collab_id_gen = Arc::new(Mutex::new(CollabIDGen::new(NonZeroNodeId(1))));
let collab_id_by_object_id = Arc::new(DashMap::new());
Ok(Self {
db,
collab_id_gen,
collab_id_by_object_id,
collab_groups: Default::default(),
client_streams: Default::default(),
})
}
fn create_collab_id(&self, object_id: &str) -> Result<CollabId, WSError> {
let collab_id = self.collab_id_gen.lock().next_id();
let collab_key = make_collab_id_key(object_id.as_ref());
self.db.with_write_txn(|w_txn| {
w_txn.insert(collab_key.as_ref(), collab_id.to_be_bytes())?;
Ok(())
})?;
Ok(collab_id)
}
fn get_collab_id(&self, object_id: &str) -> Option<CollabId> {
let collab_key = make_collab_id_key(object_id.as_ref());
let read_txn = self.db.read_txn();
let value = read_txn.get(collab_key.as_ref()).ok()??;
let mut bytes = [0; COLLAB_ID_LEN];
bytes[0..COLLAB_ID_LEN].copy_from_slice(value.as_ref());
Some(CollabId::from_be_bytes(bytes))
}
fn get_or_create_collab_id(&self, object_id: &str) -> Result<CollabId, WSError> {
let collab_id = self.get_collab_id(object_id);
if let Some(collab_id) = collab_id {
self.create_group_if_need(collab_id, object_id);
Ok(collab_id)
} else {
let collab_id = self.create_collab_id(object_id)?;
self
.collab_id_by_object_id
.insert(object_id.to_string(), collab_id);
self.create_group_if_need(collab_id, object_id);
Ok(collab_id)
}
}
fn create_group_if_need(&self, collab_id: CollabId, object_id: &str) {
if self.collab_groups.read().contains_key(&collab_id) {
return;
}
let collab = MutexCollab::new(CollabOrigin::Empty, object_id, vec![]);
let plugin = RocksdbServerDiskPlugin::new(collab_id, self.db.clone()).unwrap();
collab.lock().add_plugin(Arc::new(plugin));
collab.initial();
let broadcast = CollabBroadcast::new(object_id, collab.clone(), 10);
let group = CollabGroup {
collab,
broadcast,
subscribers: Default::default(),
};
self.collab_groups.write().insert(collab_id, group);
}
}
impl Actor for CollabServer {
type Context = Context<Self>;
}
impl Handler<Connect> for CollabServer {
type Result = Result<(), WSError>;
fn handle(&mut self, msg: Connect, _ctx: &mut Context<Self>) -> Self::Result {
let (stream_tx, rx) = tokio::sync::mpsc::channel(100);
let stream = ClientStream::new(ClientSink(msg.socket), ReceiverStream::new(rx), stream_tx);
self.client_streams.write().insert(msg.user, stream);
Ok(())
}
}
impl Handler<Disconnect> for CollabServer {
type Result = Result<(), WSError>;
fn handle(&mut self, msg: Disconnect, _: &mut Context<Self>) -> Self::Result {
self.client_streams.write().remove(&msg.user);
Ok(())
}
}
impl Handler<ClientMessage> for CollabServer {
type Result = ResponseFuture<()>;
fn handle(&mut self, msg: ClientMessage, _ctx: &mut Context<Self>) -> Self::Result {
let object_id = msg.collab_msg.object_id();
if let Ok(collab_id) = self.get_or_create_collab_id(object_id) {
if let Some(collab_group) = self.collab_groups.write().get_mut(&collab_id) {
if let Some(stream) = self.client_streams.write().get_mut(&msg.user) {
if let Some((sink, stream)) = stream.split() {
let origin = match msg.collab_msg.origin() {
None => CollabOrigin::Empty,
Some(client) => client.clone(),
};
let sub = collab_group
.broadcast
.subscribe(origin.clone(), sink, stream);
collab_group.subscribers.insert(origin, sub);
}
}
}
let client_streams = self.client_streams.clone();
Box::pin(async move {
if let Some(client_stream) = client_streams.read().get(&msg.user) {
let _ = client_stream.stream_tx.send(Ok(msg.collab_msg)).await;
}
})
} else {
Box::pin(async move {})
}
}
}
impl actix::Supervised for CollabServer {
fn restarting(&mut self, _ctx: &mut Context<CollabServer>) {
tracing::warn!("restarting");
}
}
pub struct ClientStream {
sink: Option<ClientSink>,
stream: Option<ReceiverStream<Result<CollabMessage, WSError>>>,
stream_tx: Sender<Result<CollabMessage, WSError>>,
}
impl ClientStream {
pub fn new(
sink: ClientSink,
stream: ReceiverStream<Result<CollabMessage, WSError>>,
stream_tx: Sender<Result<CollabMessage, WSError>>,
) -> Self {
Self {
sink: Some(sink),
stream: Some(stream),
stream_tx,
}
}
pub fn split(&mut self) -> Option<(ClientSink, ReceiverStream<Result<CollabMessage, WSError>>)> {
let sink = self.sink.take()?;
let stream = self.stream.take()?;
Some((sink, stream))
}
}

12
rustfmt.toml Normal file
View file

@ -0,0 +1,12 @@
# https://rust-lang.github.io/rustfmt/?version=master&search=
max_width = 100
tab_spaces = 2
newline_style = "Auto"
match_block_trailing_comma = true
use_field_init_shorthand = true
use_try_shorthand = true
reorder_imports = true
reorder_modules = true
remove_nested_parens = true
merge_derives = true
edition = "2021"

View file

@ -1,6 +1,6 @@
use crate::component::auth::{ use crate::component::auth::{
change_password, logged_user_from_request, login, logout, register, ChangePasswordRequest, change_password, logged_user_from_request, login, logout, register, ChangePasswordRequest,
InputParamsError, LoginRequest, RegisterRequest, InputParamsError, LoginRequest, RegisterRequest,
}; };
use crate::component::token_state::SessionToken; use crate::component::token_state::SessionToken;
use crate::domain::{UserEmail, UserName, UserPassword}; use crate::domain::{UserEmail, UserName, UserPassword};
@ -11,83 +11,83 @@ use actix_web::{web, HttpResponse, Scope};
use actix_web::{HttpRequest, Result}; use actix_web::{HttpRequest, Result};
pub fn user_scope() -> Scope { pub fn user_scope() -> Scope {
web::scope("/api/user") web::scope("/api/user")
.service(web::resource("/login").route(web::post().to(login_handler))) .service(web::resource("/login").route(web::post().to(login_handler)))
.service(web::resource("/logout").route(web::get().to(logout_handler))) .service(web::resource("/logout").route(web::get().to(logout_handler)))
.service(web::resource("/register").route(web::post().to(register_handler))) .service(web::resource("/register").route(web::post().to(register_handler)))
.service(web::resource("/password").route(web::post().to(change_password_handler))) .service(web::resource("/password").route(web::post().to(change_password_handler)))
} }
async fn login_handler( async fn login_handler(
req: Json<LoginRequest>, req: Json<LoginRequest>,
state: Data<State>, state: Data<State>,
session: SessionToken, session: SessionToken,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let req = req.into_inner(); let req = req.into_inner();
let email = UserEmail::parse(req.email) let email = UserEmail::parse(req.email)
.map_err(InputParamsError::InvalidEmail)? .map_err(InputParamsError::InvalidEmail)?
.0; .0;
let password = UserPassword::parse(req.password) let password = UserPassword::parse(req.password)
.map_err(|_| InputParamsError::InvalidPassword)? .map_err(|_| InputParamsError::InvalidPassword)?
.0; .0;
let (resp, token) = login(email, password, &state).await?; let (resp, token) = login(email, password, &state).await?;
// Renews the session key, assigning existing session state to new key. // Renews the session key, assigning existing session state to new key.
session.renew(); session.renew();
if let Err(err) = session.insert_token(token) { if let Err(err) = session.insert_token(token) {
// It needs to navigate to login page in web application // It needs to navigate to login page in web application
tracing::error!("Insert session failed: {:?}", err); tracing::error!("Insert session failed: {:?}", err);
} }
Ok(HttpResponse::Ok().json(resp)) Ok(HttpResponse::Ok().json(resp))
} }
async fn logout_handler(req: HttpRequest, state: Data<State>) -> Result<HttpResponse> { async fn logout_handler(req: HttpRequest, state: Data<State>) -> Result<HttpResponse> {
let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?; let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?;
logout(logged_user, state.user.clone()).await; logout(logged_user, state.user.clone()).await;
Ok(HttpResponse::Ok().finish()) Ok(HttpResponse::Ok().finish())
} }
#[tracing::instrument(level = "debug", skip(state))] #[tracing::instrument(level = "debug", skip(state))]
async fn register_handler(req: Json<RegisterRequest>, state: Data<State>) -> Result<HttpResponse> { async fn register_handler(req: Json<RegisterRequest>, state: Data<State>) -> Result<HttpResponse> {
let req = req.into_inner(); let req = req.into_inner();
let name = UserName::parse(req.name) let name = UserName::parse(req.name)
.map_err(InputParamsError::InvalidName)? .map_err(InputParamsError::InvalidName)?
.0; .0;
let email = UserEmail::parse(req.email) let email = UserEmail::parse(req.email)
.map_err(InputParamsError::InvalidEmail)? .map_err(InputParamsError::InvalidEmail)?
.0; .0;
let password = UserPassword::parse(req.password) let password = UserPassword::parse(req.password)
.map_err(|_| InputParamsError::InvalidPassword)? .map_err(|_| InputParamsError::InvalidPassword)?
.0; .0;
let resp = register(name, email, password, &state).await?; let resp = register(name, email, password, &state).await?;
Ok(HttpResponse::Ok().json(resp)) Ok(HttpResponse::Ok().json(resp))
} }
async fn change_password_handler( async fn change_password_handler(
req: HttpRequest, req: HttpRequest,
payload: Json<ChangePasswordRequest>, payload: Json<ChangePasswordRequest>,
// session: SessionToken, // session: SessionToken,
state: Data<State>, state: Data<State>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?; let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?;
let payload = payload.into_inner(); let payload = payload.into_inner();
if payload.new_password != payload.new_password_confirm { if payload.new_password != payload.new_password_confirm {
return Err(InputParamsError::PasswordNotMatch.into()); return Err(InputParamsError::PasswordNotMatch.into());
} }
let new_password = UserPassword::parse(payload.new_password) let new_password = UserPassword::parse(payload.new_password)
.map_err(|_| InputParamsError::InvalidPassword)? .map_err(|_| InputParamsError::InvalidPassword)?
.0; .0;
change_password( change_password(
state.pg_pool.clone(), state.pg_pool.clone(),
logged_user.clone(), logged_user.clone(),
payload.current_password, payload.current_password,
new_password, new_password,
) )
.await?; .await?;
Ok(HttpResponse::Ok().finish()) Ok(HttpResponse::Ok().finish())
} }

View file

@ -1,32 +1,40 @@
use crate::component::auth::LoggedUser; use crate::component::auth::LoggedUser;
use crate::component::ws::{MessageReceivers, WSClient, WSServer};
use crate::state::State; use crate::state::State;
use actix::Addr; use actix::Addr;
use actix_web::web::{Data, Path, Payload}; use actix_web::web::{Data, Path, Payload};
use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope}; use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope};
use actix_web_actors::ws; use actix_web_actors::ws;
use secrecy::Secret;
use websocket::entities::WSUser;
use websocket::{CollabServer, CollabSession};
pub fn ws_scope() -> Scope { pub fn ws_scope() -> Scope {
web::scope("/ws").service(establish_ws_connection) web::scope("/ws").service(establish_ws_connection)
} }
#[get("/{token}")] #[get("/{token}")]
pub async fn establish_ws_connection( pub async fn establish_ws_connection(
request: HttpRequest, request: HttpRequest,
payload: Payload, payload: Payload,
token: Path<String>, token: Path<String>,
state: Data<State>, state: Data<State>,
server: Data<Addr<WSServer>>, server: Data<Addr<CollabServer>>,
msg_receivers: Data<MessageReceivers>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
tracing::info!("establish_ws_connection"); let user = LoggedUser::from_token(&state.config.application.server_key, token.as_str())?;
let user = LoggedUser::from_token(&state.config.application.server_key, token.as_str())?; let client = CollabSession::new(user.into(), server.get_ref().clone());
let client = WSClient::new(user, server.get_ref().clone(), msg_receivers); match ws::start(client, &request, payload) {
match ws::start(client, &request, payload) { Ok(response) => Ok(response),
Ok(response) => Ok(response), Err(e) => {
Err(e) => { tracing::error!("ws connection error: {:?}", e);
tracing::error!("ws connection error: {:?}", e); Err(e)
Err(e) },
} }
} }
impl From<LoggedUser> for WSUser {
fn from(user: LoggedUser) -> Self {
Self {
user_id: Secret::new(user.expose_secret().to_string()),
}
}
} }

View file

@ -10,6 +10,9 @@ use actix_session::SessionMiddleware;
use actix_web::cookie::Key; use actix_web::cookie::Key;
use actix_web::{dev::Server, web, web::Data, App, HttpServer}; use actix_web::{dev::Server, web, web::Data, App, HttpServer};
use actix::Actor;
use collab_persistence::kv::rocks_kv::RocksCollabDB;
use openssl::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod}; use openssl::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod};
use openssl::x509::X509; use openssl::x509::X509;
use secrecy::{ExposeSecret, Secret}; use secrecy::{ExposeSecret, Secret};
@ -18,121 +21,136 @@ use sqlx::{postgres::PgPoolOptions, PgPool};
use std::net::TcpListener; use std::net::TcpListener;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing_actix_web::TracingLogger; use tracing_actix_web::TracingLogger;
use websocket::CollabServer;
pub struct Application { pub struct Application {
port: u16, port: u16,
server: Server, server: Server,
} }
impl Application { impl Application {
pub async fn build(config: Config, state: State) -> Result<Self, anyhow::Error> { pub async fn build(config: Config, state: State) -> Result<Self, anyhow::Error> {
let address = format!("{}:{}", config.application.host, config.application.port); let address = format!("{}:{}", config.application.host, config.application.port);
let listener = TcpListener::bind(&address)?; let listener = TcpListener::bind(&address)?;
let port = listener.local_addr().unwrap().port(); let port = listener.local_addr().unwrap().port();
let server = run(listener, state, config).await?; let server = run(listener, state, config).await?;
Ok(Self { port, server }) Ok(Self { port, server })
} }
pub async fn run_until_stopped(self) -> Result<(), std::io::Error> { pub async fn run_until_stopped(self) -> Result<(), std::io::Error> {
self.server.await self.server.await
} }
pub fn port(&self) -> u16 { pub fn port(&self) -> u16 {
self.port self.port
} }
} }
pub async fn run( pub async fn run(
listener: TcpListener, listener: TcpListener,
state: State, state: State,
config: Config, config: Config,
) -> Result<Server, anyhow::Error> { ) -> Result<Server, anyhow::Error> {
let redis_store = RedisSessionStore::new(config.redis_uri.expose_secret()) let redis_store = RedisSessionStore::new(config.redis_uri.expose_secret())
.await .await
.map_err(|e| { .map_err(|e| {
anyhow::anyhow!( anyhow::anyhow!(
"Failed to connect to Redis at {:?}: {:?}", "Failed to connect to Redis at {:?}: {:?}",
config.redis_uri, config.redis_uri,
e e
) )
})?; })?;
let pair = get_certificate_and_server_key(&config); let pair = get_certificate_and_server_key(&config);
let key = pair let key = pair
.as_ref() .as_ref()
.map(|(_, server_key)| Key::from(server_key.expose_secret().as_bytes())) .map(|(_, server_key)| Key::from(server_key.expose_secret().as_bytes()))
.unwrap_or_else(Key::generate); .unwrap_or_else(Key::generate);
let mut server = HttpServer::new(move || {
App::new() let collab_server = CollabServer::new(state.rocksdb.clone()).unwrap().start();
// Session middleware let mut server = HttpServer::new(move || {
App::new()
.wrap( .wrap(
SessionMiddleware::builder(redis_store.clone(), key.clone()) SessionMiddleware::builder(redis_store.clone(), key.clone())
.cookie_name(HEADER_TOKEN.to_string()) .cookie_name(HEADER_TOKEN.to_string())
.build(), .build(),
) )
// .wrap(ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, add_error_header))
.wrap(IdentityMiddleware::default()) .wrap(IdentityMiddleware::default())
.wrap(default_cors()) .wrap(default_cors())
.wrap(TracingLogger::default()) .wrap(TracingLogger::default())
.app_data(web::JsonConfig::default().limit(4096)) .app_data(web::JsonConfig::default().limit(4096))
.service(user_scope()) .service(user_scope())
.service(ws_scope()) .service(ws_scope())
.app_data(Data::new(collab_server.clone()))
.app_data(Data::new(state.clone())) .app_data(Data::new(state.clone()))
}); });
server = match pair { server = match pair {
None => server.listen(listener)?, None => server.listen(listener)?,
Some((certificate, _)) => { Some((certificate, _)) => {
server.listen_openssl(listener, make_ssl_acceptor_builder(certificate))? server.listen_openssl(listener, make_ssl_acceptor_builder(certificate))?
} },
}; };
Ok(server.run()) Ok(server.run())
} }
fn get_certificate_and_server_key(config: &Config) -> Option<(Secret<String>, Secret<String>)> { fn get_certificate_and_server_key(config: &Config) -> Option<(Secret<String>, Secret<String>)> {
let tls_config = config.application.tls_config.as_ref()?; let tls_config = config.application.tls_config.as_ref()?;
match tls_config { match tls_config {
TlsConfig::NoTls => None, TlsConfig::NoTls => None,
TlsConfig::SelfSigned => Some(create_self_signed_certificate().unwrap()), TlsConfig::SelfSigned => Some(create_self_signed_certificate().unwrap()),
} }
} }
pub async fn init_state(config: &Config) -> State { pub async fn init_state(config: &Config) -> State {
let pg_pool = get_connection_pool(&config.database) let pg_pool = get_connection_pool(&config.database)
.await .await
.unwrap_or_else(|_| panic!("Failed to connect to Postgres at {:?}.", config.database)); .unwrap_or_else(|_| panic!("Failed to connect to Postgres at {:?}.", config.database));
State { std::fs::create_dir_all(config.application.rocksdb_db_dir()).expect("create rocksdb db dir");
pg_pool, let rocksdb = Arc::new(RocksCollabDB::open(config.application.rocksdb_db_dir()).unwrap());
config: Arc::new(config.clone()), State {
user: Arc::new(Default::default()), pg_pool,
id_gen: Arc::new(RwLock::new(Snowflake::new(1))), rocksdb,
} config: Arc::new(config.clone()),
user: Arc::new(Default::default()),
id_gen: Arc::new(RwLock::new(Snowflake::new(1))),
}
} }
pub async fn get_connection_pool(setting: &DatabaseSetting) -> Result<PgPool, sqlx::Error> { pub async fn get_connection_pool(setting: &DatabaseSetting) -> Result<PgPool, sqlx::Error> {
PgPoolOptions::new() PgPoolOptions::new()
.acquire_timeout(std::time::Duration::from_secs(5)) .acquire_timeout(std::time::Duration::from_secs(5))
.connect_with(setting.with_db()) .connect_with(setting.with_db())
.await .await
} }
fn make_ssl_acceptor_builder(certificate: Secret<String>) -> SslAcceptorBuilder { fn make_ssl_acceptor_builder(certificate: Secret<String>) -> SslAcceptorBuilder {
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
let x509_cert = X509::from_pem(certificate.expose_secret().as_bytes()).unwrap(); let x509_cert = X509::from_pem(certificate.expose_secret().as_bytes()).unwrap();
builder.set_certificate(&x509_cert).unwrap(); builder.set_certificate(&x509_cert).unwrap();
builder builder
.set_private_key_file("./cert/key.pem", SslFiletype::PEM) .set_private_key_file("./cert/key.pem", SslFiletype::PEM)
.unwrap(); .unwrap();
builder builder
.set_certificate_chain_file("./cert/cert.pem") .set_certificate_chain_file("./cert/cert.pem")
.unwrap(); .unwrap();
builder builder
.set_min_proto_version(Some(openssl::ssl::SslVersion::TLS1_2)) .set_min_proto_version(Some(openssl::ssl::SslVersion::TLS1_2))
.unwrap(); .unwrap();
builder builder
.set_max_proto_version(Some(openssl::ssl::SslVersion::TLS1_3)) .set_max_proto_version(Some(openssl::ssl::SslVersion::TLS1_3))
.unwrap(); .unwrap();
builder builder
} }
// fn add_error_header<B>(
// res: dev::ServiceResponse<B>,
// ) -> Result<ErrorHandlerResponse<B>, actix_web::Error> {
// tracing::error!("{:?}", res.request());
// Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
// }

View file

@ -5,76 +5,76 @@ use thiserror::Error;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum AuthError { pub enum AuthError {
#[error("Credentials is invalid")] #[error("Credentials is invalid")]
InvalidCredentials(#[source] anyhow::Error), InvalidCredentials(#[source] anyhow::Error),
#[error("User is not exist")] #[error("User is not exist")]
UserNotExist(#[source] anyhow::Error), UserNotExist(#[source] anyhow::Error),
#[error("{} is already used", email)] #[error("{} is already used", email)]
UserAlreadyExist { email: String }, UserAlreadyExist { email: String },
#[error("Invalid password")] #[error("Invalid password")]
InvalidPassword, InvalidPassword,
#[error("User is unauthorized")] #[error("User is unauthorized")]
Unauthorized, Unauthorized,
#[error("User internal error")] #[error("User internal error")]
InternalError(#[from] anyhow::Error), InternalError(#[from] anyhow::Error),
#[error("Parser uuid failed: {}", err)] #[error("Parser uuid failed: {}", err)]
InvalidUuid { err: String }, InvalidUuid { err: String },
} }
pub fn internal_error(error: anyhow::Error) -> AuthError { pub fn internal_error(error: anyhow::Error) -> AuthError {
AuthError::InternalError(error) AuthError::InternalError(error)
} }
impl actix_web::error::ResponseError for AuthError { impl actix_web::error::ResponseError for AuthError {
fn status_code(&self) -> StatusCode { fn status_code(&self) -> StatusCode {
match *self { match *self {
AuthError::InvalidCredentials(_) => StatusCode::UNAUTHORIZED, AuthError::InvalidCredentials(_) => StatusCode::UNAUTHORIZED,
AuthError::UserNotExist(_) => StatusCode::UNAUTHORIZED, AuthError::UserNotExist(_) => StatusCode::UNAUTHORIZED,
AuthError::UserAlreadyExist { .. } => StatusCode::BAD_REQUEST, AuthError::UserAlreadyExist { .. } => StatusCode::BAD_REQUEST,
AuthError::InvalidPassword => StatusCode::UNAUTHORIZED, AuthError::InvalidPassword => StatusCode::UNAUTHORIZED,
AuthError::Unauthorized => StatusCode::UNAUTHORIZED, AuthError::Unauthorized => StatusCode::UNAUTHORIZED,
AuthError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, AuthError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR,
AuthError::InvalidUuid { .. } => StatusCode::UNAUTHORIZED, AuthError::InvalidUuid { .. } => StatusCode::UNAUTHORIZED,
}
} }
}
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HttpResponse::build(self.status_code()).body(self.to_string()) HttpResponse::build(self.status_code()).body(self.to_string())
} }
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum InputParamsError { pub enum InputParamsError {
#[error("Invalid name")] #[error("Invalid name")]
InvalidName(String), InvalidName(String),
#[error("Invalid email format")] #[error("Invalid email format")]
InvalidEmail(String), InvalidEmail(String),
#[error("Invalid password")] #[error("Invalid password")]
InvalidPassword, InvalidPassword,
#[error("You entered two different new passwords")] #[error("You entered two different new passwords")]
PasswordNotMatch, PasswordNotMatch,
} }
impl actix_web::error::ResponseError for InputParamsError { impl actix_web::error::ResponseError for InputParamsError {
fn status_code(&self) -> StatusCode { fn status_code(&self) -> StatusCode {
match *self { match *self {
InputParamsError::InvalidName(_) => StatusCode::BAD_REQUEST, InputParamsError::InvalidName(_) => StatusCode::BAD_REQUEST,
InputParamsError::InvalidEmail(_) => StatusCode::BAD_REQUEST, InputParamsError::InvalidEmail(_) => StatusCode::BAD_REQUEST,
InputParamsError::InvalidPassword => StatusCode::BAD_REQUEST, InputParamsError::InvalidPassword => StatusCode::BAD_REQUEST,
InputParamsError::PasswordNotMatch => StatusCode::BAD_REQUEST, InputParamsError::PasswordNotMatch => StatusCode::BAD_REQUEST,
}
} }
}
fn error_response(&self) -> HttpResponse { fn error_response(&self) -> HttpResponse {
HttpResponse::build(self.status_code()).body(self.to_string()) HttpResponse::build(self.status_code()).body(self.to_string())
} }
} }

View file

@ -8,84 +8,85 @@ use secrecy::{ExposeSecret, Secret};
use sqlx::PgPool; use sqlx::PgPool;
pub struct Credentials { pub struct Credentials {
pub email: String, pub email: String,
pub password: Secret<String>, pub password: Secret<String>,
} }
#[tracing::instrument(level = "debug", skip(credentials, pool))] #[tracing::instrument(level = "debug", skip(credentials, pool))]
pub async fn validate_credentials( pub async fn validate_credentials(
credentials: Credentials, credentials: Credentials,
pool: &PgPool, pool: &PgPool,
) -> Result<i64, AuthError> { ) -> Result<i64, AuthError> {
let mut uid = None; let mut uid = None;
let mut expected_hash_password = Secret::new( let mut expected_hash_password = Secret::new(
"$argon2id$v=19$m=15000,t=2,p=1$\ "$argon2id$v=19$m=15000,t=2,p=1$\
gZiV/M1gPc22ElAH/Jh1Hw$\ gZiV/M1gPc22ElAH/Jh1Hw$\
CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno" CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno"
.to_string(), .to_string(),
); );
if let Some((stored_uid, stored_hash_password)) = if let Some((stored_uid, stored_hash_password)) =
get_stored_credentials(&credentials.email, pool).await? get_stored_credentials(&credentials.email, pool).await?
{ {
uid = Some(stored_uid); uid = Some(stored_uid);
expected_hash_password = stored_hash_password; expected_hash_password = stored_hash_password;
} }
spawn_blocking_with_tracing(move || { spawn_blocking_with_tracing(move || {
verify_password_hash(expected_hash_password, credentials.password) verify_password_hash(expected_hash_password, credentials.password)
}) })
.await .await
.context("Failed to spawn blocking task.")??; .context("Failed to spawn blocking task.")??;
uid.ok_or_else(|| anyhow::anyhow!("Unknown email.")) uid
.map_err(AuthError::InvalidCredentials) .ok_or_else(|| anyhow::anyhow!("Unknown email."))
.map_err(AuthError::InvalidCredentials)
} }
pub fn compute_hash_password(password: &[u8]) -> Result<Secret<String>, anyhow::Error> { pub fn compute_hash_password(password: &[u8]) -> Result<Secret<String>, anyhow::Error> {
let salt = SaltString::generate(&mut rand::thread_rng()); let salt = SaltString::generate(&mut rand::thread_rng());
let password = Argon2::new( let password = Argon2::new(
Algorithm::Argon2id, Algorithm::Argon2id,
Version::V0x13, Version::V0x13,
Params::new(15000, 2, 1, None).unwrap(), Params::new(15000, 2, 1, None).unwrap(),
) )
.hash_password(password, &salt)? .hash_password(password, &salt)?
.to_string(); .to_string();
Ok(Secret::new(password)) Ok(Secret::new(password))
} }
#[tracing::instrument(level = "debug", skip(email, pool))] #[tracing::instrument(level = "debug", skip(email, pool))]
async fn get_stored_credentials( async fn get_stored_credentials(
email: &str, email: &str,
pool: &PgPool, pool: &PgPool,
) -> Result<Option<(i64, Secret<String>)>, anyhow::Error> { ) -> Result<Option<(i64, Secret<String>)>, anyhow::Error> {
let row = sqlx::query!( let row = sqlx::query!(
r#" r#"
SELECT uid, password SELECT uid, password
FROM users FROM users
WHERE email = $1 WHERE email = $1
"#, "#,
email, email,
) )
.fetch_optional(pool) .fetch_optional(pool)
.await .await
.context("Failed to performed a query to retrieve stored credentials.")? .context("Failed to performed a query to retrieve stored credentials.")?
.map(|row| (row.uid, Secret::new(row.password))); .map(|row| (row.uid, Secret::new(row.password)));
Ok(row) Ok(row)
} }
fn verify_password_hash( fn verify_password_hash(
expected_password_hash: Secret<String>, expected_password_hash: Secret<String>,
password_candidate: Secret<String>, password_candidate: Secret<String>,
) -> Result<(), AuthError> { ) -> Result<(), AuthError> {
let expected_hash_password = PasswordHash::new(expected_password_hash.expose_secret()) let expected_hash_password = PasswordHash::new(expected_password_hash.expose_secret())
.context("Failed to parse hash in PHC string format.")?; .context("Failed to parse hash in PHC string format.")?;
Argon2::default() Argon2::default()
.verify_password( .verify_password(
password_candidate.expose_secret().as_bytes(), password_candidate.expose_secret().as_bytes(),
&expected_hash_password, &expected_hash_password,
) )
.context("Invalid password.") .context("Invalid password.")
.map_err(|_| AuthError::InvalidPassword) .map_err(|_| AuthError::InvalidPassword)
} }

View file

@ -1,5 +1,5 @@
use crate::component::auth::{ use crate::component::auth::{
compute_hash_password, internal_error, validate_credentials, AuthError, Credentials, compute_hash_password, internal_error, validate_credentials, AuthError, Credentials,
}; };
use crate::config::env::domain; use crate::config::env::domain;
use crate::state::{State, UserCache}; use crate::state::{State, UserCache};
@ -18,234 +18,234 @@ use token::{create_token, parse_token, TokenError};
use tokio::sync::RwLock; use tokio::sync::RwLock;
pub async fn login( pub async fn login(
email: String, email: String,
password: String, password: String,
state: &State, state: &State,
) -> Result<(LoginResponse, Secret<Token>), AuthError> { ) -> Result<(LoginResponse, Secret<Token>), AuthError> {
let credentials = Credentials { let credentials = Credentials {
email, email,
password: Secret::new(password), password: Secret::new(password),
}; };
let server_key = &state.config.application.server_key; let server_key = &state.config.application.server_key;
match validate_credentials(credentials, &state.pg_pool).await { match validate_credentials(credentials, &state.pg_pool).await {
Ok(uid) => { Ok(uid) => {
let token = Token::create_token(uid, server_key)?; let token = Token::create_token(uid, server_key)?;
let logged_user = LoggedUser::new(uid); let logged_user = LoggedUser::new(uid);
state.user.write().await.authorized(logged_user); state.user.write().await.authorized(logged_user);
Ok(( Ok((
LoginResponse { LoginResponse {
token: token.0.clone(), token: token.0.clone(),
uid: uid.to_string(), uid: uid.to_string(),
}, },
Secret::new(token), Secret::new(token),
)) ))
} },
Err(err) => Err(err), Err(err) => Err(err),
} }
} }
pub async fn logout(logged_user: LoggedUser, cache: Arc<RwLock<UserCache>>) { pub async fn logout(logged_user: LoggedUser, cache: Arc<RwLock<UserCache>>) {
cache.write().await.unauthorized(logged_user); cache.write().await.unauthorized(logged_user);
} }
pub async fn register( pub async fn register(
username: String, username: String,
email: String, email: String,
password: String, password: String,
state: &State, state: &State,
) -> Result<RegisterResponse, AuthError> { ) -> Result<RegisterResponse, AuthError> {
let pg_pool = state.pg_pool.clone(); let pg_pool = state.pg_pool.clone();
let server_key = &state.config.application.server_key; let server_key = &state.config.application.server_key;
let mut transaction = pg_pool let mut transaction = pg_pool
.begin() .begin()
.await .await
.context("Failed to acquire a Postgres connection to register user") .context("Failed to acquire a Postgres connection to register user")
.map_err(internal_error)?; .map_err(internal_error)?;
if is_email_exist(&mut transaction, email.as_ref()) if is_email_exist(&mut transaction, email.as_ref())
.await .await
.map_err(internal_error)? .map_err(internal_error)?
{ {
return Err(AuthError::UserAlreadyExist { email }); return Err(AuthError::UserAlreadyExist { email });
} }
let uid = state.id_gen.write().await.next_id(); let uid = state.id_gen.write().await.next_id();
let token = Token::create_token(uid, server_key)?; let token = Token::create_token(uid, server_key)?;
let password = compute_hash_password(password.as_bytes()).map_err(internal_error)?; let password = compute_hash_password(password.as_bytes()).map_err(internal_error)?;
let _ = sqlx::query!( let _ = sqlx::query!(
r#" r#"
INSERT INTO users (uid, email, username, create_time, password) INSERT INTO users (uid, email, username, create_time, password)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
"#, "#,
uid, uid,
email, email,
username, username,
Utc::now(), Utc::now(),
password.expose_secret(), password.expose_secret(),
) )
.execute(&mut transaction) .execute(&mut transaction)
.await
.context("Save user to disk failed")
.map_err(internal_error)?;
transaction
.commit()
.await .await
.context("Save user to disk failed") .context("Failed to commit SQL transaction to register user.")
.map_err(internal_error)?; .map_err(internal_error)?;
transaction let logged_user = LoggedUser::new(uid);
.commit() state.user.write().await.authorized(logged_user);
.await
.context("Failed to commit SQL transaction to register user.")
.map_err(internal_error)?;
let logged_user = LoggedUser::new(uid); Ok(RegisterResponse {
state.user.write().await.authorized(logged_user); token: token.0.clone(),
})
Ok(RegisterResponse {
token: token.0.clone(),
})
} }
pub async fn change_password( pub async fn change_password(
pg_pool: PgPool, pg_pool: PgPool,
logged_user: LoggedUser, logged_user: LoggedUser,
current_password: String, current_password: String,
new_password: String, new_password: String,
) -> Result<(), AuthError> { ) -> Result<(), AuthError> {
let mut transaction = pg_pool let mut transaction = pg_pool
.begin() .begin()
.await .await
.context("Failed to acquire a Postgres connection to change password") .context("Failed to acquire a Postgres connection to change password")
.map_err(internal_error)?; .map_err(internal_error)?;
let email = get_user_email(*logged_user.expose_secret(), &mut transaction).await?; let email = get_user_email(*logged_user.expose_secret(), &mut transaction).await?;
// check password // check password
let credentials = Credentials { let credentials = Credentials {
email, email,
password: Secret::new(current_password), password: Secret::new(current_password),
}; };
let _ = validate_credentials(credentials, &pg_pool).await?; let _ = validate_credentials(credentials, &pg_pool).await?;
// Hash password // Hash password
let new_hash_password = let new_hash_password =
spawn_blocking_with_tracing(move || compute_hash_password(new_password.as_bytes())) spawn_blocking_with_tracing(move || compute_hash_password(new_password.as_bytes()))
.await .await
.context("Failed to hash password")??; .context("Failed to hash password")??;
// Save password to disk // Save password to disk
let sql = "UPDATE users SET password = $1 where uid = $2"; let sql = "UPDATE users SET password = $1 where uid = $2";
let _ = sqlx::query(sql) let _ = sqlx::query(sql)
.bind(new_hash_password.expose_secret()) .bind(new_hash_password.expose_secret())
.bind(logged_user.expose_secret()) .bind(logged_user.expose_secret())
.execute(&mut transaction) .execute(&mut transaction)
.await .await
.context("Failed to change user's password in the database.")?; .context("Failed to change user's password in the database.")?;
transaction transaction
.commit() .commit()
.await .await
.context("Failed to commit SQL transaction to change user's password") .context("Failed to commit SQL transaction to change user's password")
.map_err(internal_error)?; .map_err(internal_error)?;
Ok(()) Ok(())
} }
pub async fn get_user_email( pub async fn get_user_email(
uid: i64, uid: i64,
transaction: &mut Transaction<'_, Postgres>, transaction: &mut Transaction<'_, Postgres>,
) -> Result<String, anyhow::Error> { ) -> Result<String, anyhow::Error> {
let row = sqlx::query!( let row = sqlx::query!(
r#" r#"
SELECT email SELECT email
FROM users FROM users
WHERE uid = $1 WHERE uid = $1
"#, "#,
uid, uid,
) )
.fetch_one(transaction) .fetch_one(transaction)
.await .await
.context("Failed to retrieve the username`")?; .context("Failed to retrieve the username`")?;
Ok(row.email) Ok(row.email)
} }
/// TODO: cache this state in [State] /// TODO: cache this state in [State]
async fn is_email_exist( async fn is_email_exist(
transaction: &mut Transaction<'_, Postgres>, transaction: &mut Transaction<'_, Postgres>,
email: &str, email: &str,
) -> Result<bool, anyhow::Error> { ) -> Result<bool, anyhow::Error> {
let result = sqlx::query(r#"SELECT email FROM users WHERE email = $1"#) let result = sqlx::query(r#"SELECT email FROM users WHERE email = $1"#)
.bind(email) .bind(email)
.fetch_optional(transaction) .fetch_optional(transaction)
.await?; .await?;
Ok(result.is_some()) Ok(result.is_some())
} }
#[derive(Default, Deserialize, Debug)] #[derive(Default, Deserialize, Debug)]
pub struct LoginRequest { pub struct LoginRequest {
pub email: String, pub email: String,
pub password: String, pub password: String,
} }
#[derive(Default, Serialize, Deserialize, Debug)] #[derive(Default, Serialize, Deserialize, Debug)]
pub struct LoginResponse { pub struct LoginResponse {
pub token: String, pub token: String,
pub uid: String, pub uid: String,
} }
#[derive(Default, Deserialize, Debug)] #[derive(Default, Deserialize, Debug)]
pub struct RegisterRequest { pub struct RegisterRequest {
pub email: String, pub email: String,
pub password: String, pub password: String,
pub name: String, pub name: String,
} }
#[derive(Default, Serialize, Deserialize, Debug)] #[derive(Default, Serialize, Deserialize, Debug)]
pub struct RegisterResponse { pub struct RegisterResponse {
pub token: String, pub token: String,
} }
#[derive(Default, Deserialize, Debug)] #[derive(Default, Deserialize, Debug)]
pub struct ChangePasswordRequest { pub struct ChangePasswordRequest {
pub current_password: String, pub current_password: String,
pub new_password: String, pub new_password: String,
pub new_password_confirm: String, pub new_password_confirm: String,
} }
#[derive(Clone, Default)] #[derive(Clone, Default)]
pub struct WrapI64(i64); pub struct SecretI64(i64);
impl Copy for WrapI64 {} impl Copy for SecretI64 {}
impl DefaultIsZeroes for WrapI64 {} impl DefaultIsZeroes for SecretI64 {}
impl DebugSecret for WrapI64 {} impl DebugSecret for SecretI64 {}
impl CloneableSecret for WrapI64 {} impl CloneableSecret for SecretI64 {}
impl std::ops::Deref for WrapI64 { impl std::ops::Deref for SecretI64 {
type Target = i64; type Target = i64;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.0 &self.0
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct LoggedUser(Secret<WrapI64>); pub struct LoggedUser(Secret<SecretI64>);
impl From<Claim> for LoggedUser { impl From<Claim> for LoggedUser {
fn from(c: Claim) -> Self { fn from(c: Claim) -> Self {
Self(Secret::new(WrapI64(c.uid))) Self(Secret::new(SecretI64(c.uid)))
} }
} }
impl LoggedUser { impl LoggedUser {
pub fn new(uid: i64) -> Self { pub fn new(uid: i64) -> Self {
Self(Secret::new(WrapI64(uid))) Self(Secret::new(SecretI64(uid)))
} }
pub fn from_token(server_key: &Secret<String>, token: &str) -> Result<Self, AuthError> { pub fn from_token(server_key: &Secret<String>, token: &str) -> Result<Self, AuthError> {
let user: LoggedUser = Token::decode_token(server_key, token)?.into(); let user: LoggedUser = Token::decode_token(server_key, token)?.into();
Ok(user) Ok(user)
} }
pub fn expose_secret(&self) -> &i64 { pub fn expose_secret(&self) -> &i64 {
self.0.expose_secret() self.0.expose_secret()
} }
} }
pub const HEADER_TOKEN: &str = "token"; pub const HEADER_TOKEN: &str = "token";
@ -253,68 +253,68 @@ pub const EXPIRED_DURATION_DAYS: i64 = 30;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct Claim { pub struct Claim {
iss: String, iss: String,
uid: i64, uid: i64,
} }
impl Claim { impl Claim {
pub fn with_user_id(uid: i64) -> Self { pub fn with_user_id(uid: i64) -> Self {
Self { iss: domain(), uid } Self { iss: domain(), uid }
} }
} }
#[derive(Clone, Default, Serialize, Deserialize)] #[derive(Clone, Default, Serialize, Deserialize)]
pub struct Token(pub String); pub struct Token(pub String);
impl Zeroize for Token { impl Zeroize for Token {
fn zeroize(&mut self) { fn zeroize(&mut self) {
self.0.zeroize() self.0.zeroize()
} }
} }
impl Token { impl Token {
pub fn create_token(uid: i64, server_key: &Secret<String>) -> Result<Self, AuthError> { pub fn create_token(uid: i64, server_key: &Secret<String>) -> Result<Self, AuthError> {
let claim = Claim::with_user_id(uid); let claim = Claim::with_user_id(uid);
let token = create_token( let token = create_token(
server_key.expose_secret().as_str(), server_key.expose_secret().as_str(),
claim, claim,
Duration::days(EXPIRED_DURATION_DAYS), Duration::days(EXPIRED_DURATION_DAYS),
) )
.map_err(|e| match e { .map_err(|e| match e {
TokenError::Jwt(_) => AuthError::Unauthorized, TokenError::Jwt(_) => AuthError::Unauthorized,
TokenError::Expired => AuthError::Unauthorized, TokenError::Expired => AuthError::Unauthorized,
})?; })?;
Ok(Self(token)) Ok(Self(token))
} }
pub fn decode_token(server_key: &Secret<String>, token: &str) -> Result<Claim, AuthError> { pub fn decode_token(server_key: &Secret<String>, token: &str) -> Result<Claim, AuthError> {
parse_token::<Claim>(server_key.expose_secret().as_str(), token) parse_token::<Claim>(server_key.expose_secret().as_str(), token)
.map_err(|_| AuthError::Unauthorized) .map_err(|_| AuthError::Unauthorized)
} }
} }
pub fn logged_user_from_request( pub fn logged_user_from_request(
request: &HttpRequest, request: &HttpRequest,
server_key: &Secret<String>, server_key: &Secret<String>,
) -> Result<LoggedUser, AuthError> { ) -> Result<LoggedUser, AuthError> {
match request.headers().get(HEADER_TOKEN) { match request.headers().get(HEADER_TOKEN) {
None => Err(AuthError::Unauthorized), None => Err(AuthError::Unauthorized),
Some(header) => match header.to_str() { Some(header) => match header.to_str() {
Ok(token_str) => LoggedUser::from_token(server_key, token_str), Ok(token_str) => LoggedUser::from_token(server_key, token_str),
Err(_) => Err(AuthError::Unauthorized), Err(_) => Err(AuthError::Unauthorized),
}, },
} }
} }
pub fn uid_from_request( pub fn uid_from_request(
request: &HttpRequest, request: &HttpRequest,
server_key: &Secret<String>, server_key: &Secret<String>,
) -> Result<Secret<i64>, AuthError> { ) -> Result<Secret<i64>, AuthError> {
match request.headers().get(HEADER_TOKEN) { match request.headers().get(HEADER_TOKEN) {
Some(header) => match header.to_str() { Some(header) => match header.to_str() {
Ok(val) => Token::decode_token(server_key, val).map(|claim| Secret::new(claim.uid)), Ok(val) => Token::decode_token(server_key, val).map(|claim| Secret::new(claim.uid)),
Err(_) => Err(AuthError::Unauthorized), Err(_) => Err(AuthError::Unauthorized),
}, },
None => Err(AuthError::Unauthorized), None => Err(AuthError::Unauthorized),
} }
} }

View file

@ -1,3 +1,2 @@
pub mod auth; pub mod auth;
pub mod token_state; pub mod token_state;
pub mod ws;

View file

@ -9,30 +9,30 @@ use std::future::{ready, Ready};
pub struct SessionToken(Session); pub struct SessionToken(Session);
impl SessionToken { impl SessionToken {
const TOKEN_ID_KEY: &'static str = "token"; const TOKEN_ID_KEY: &'static str = "token";
pub fn renew(&self) { pub fn renew(&self) {
self.0.renew(); self.0.renew();
} }
pub fn insert_token(&self, token: Secret<Token>) -> Result<(), SessionInsertError> { pub fn insert_token(&self, token: Secret<Token>) -> Result<(), SessionInsertError> {
self.0.insert(Self::TOKEN_ID_KEY, token.expose_secret()) self.0.insert(Self::TOKEN_ID_KEY, token.expose_secret())
} }
pub fn get_token(&self) -> Result<Option<String>, SessionGetError> { pub fn get_token(&self) -> Result<Option<String>, SessionGetError> {
self.0.get(Self::TOKEN_ID_KEY) self.0.get(Self::TOKEN_ID_KEY)
} }
pub fn log_out(self) { pub fn log_out(self) {
self.0.purge() self.0.purge()
} }
} }
impl FromRequest for SessionToken { impl FromRequest for SessionToken {
type Error = <Session as FromRequest>::Error; type Error = <Session as FromRequest>::Error;
type Future = Ready<Result<SessionToken, Self::Error>>; type Future = Ready<Result<SessionToken, Self::Error>>;
fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
ready(Ok(SessionToken(req.get_session()))) ready(Ok(SessionToken(req.get_session())))
} }
} }

View file

@ -1,164 +0,0 @@
use crate::component::auth::LoggedUser;
use crate::component::ws::entities::{
Connect, Disconnect, MessageDetail, MessagePayload, Socket, WebSocketMessage,
};
use crate::component::ws::server::WSServer;
use crate::component::ws::{HEARTBEAT_INTERVAL, PING_TIMEOUT};
use actix::*;
use actix_http::ws::Message::*;
use actix_web::web::Data;
use actix_web_actors::ws;
use bytes::Bytes;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
pub trait MessageReceiver: Send + Sync {
fn receive(&self, data: WSClientData);
}
#[derive(Default)]
pub struct MessageReceivers {
inner: HashMap<u8, Arc<dyn MessageReceiver>>,
}
impl MessageReceivers {
pub fn new() -> Self {
MessageReceivers::default()
}
pub fn insert(&mut self, channel: u8, receiver: Arc<dyn MessageReceiver>) {
self.inner.insert(channel, receiver);
}
pub fn get(&self, source: u8) -> Option<&Arc<dyn MessageReceiver>> {
self.inner.get(&source)
}
}
#[allow(dead_code)]
pub struct WSClientData {
pub(crate) socket: Socket,
pub(crate) detail: MessageDetail,
}
pub struct WSClient {
user: Arc<LoggedUser>,
server: Addr<WSServer>,
msg_receivers: Data<MessageReceivers>,
hb: Instant,
}
impl WSClient {
pub fn new(
user: LoggedUser,
server: Addr<WSServer>,
msg_receivers: Data<MessageReceivers>,
) -> Self {
Self {
user: Arc::new(user),
server,
msg_receivers,
hb: Instant::now(),
}
}
fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
ctx.run_interval(HEARTBEAT_INTERVAL, |client, ctx| {
if Instant::now().duration_since(client.hb) > PING_TIMEOUT {
client.server.do_send(Disconnect {
user: client.user.clone(),
});
ctx.stop();
} else {
ctx.ping(b"");
}
});
}
fn handle_binary_message(&self, bytes: Bytes, socket: Socket) {
let MessagePayload { channel, detail } = MessagePayload::from_bytes(&bytes);
match self.msg_receivers.get(channel) {
None => {
tracing::error!("Can't find the receiver for {:?}", channel);
}
Some(handler) => {
let client_data = WSClientData { socket, detail };
handler.receive(client_data);
}
}
}
}
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WSClient {
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
match msg {
Ok(Ping(msg)) => {
self.hb = Instant::now();
ctx.pong(&msg);
}
Ok(Pong(_msg)) => {
// tracing::trace!("Receive {} pong {:?}", &self.session_id, &msg);
self.hb = Instant::now();
}
Ok(Binary(bytes)) => {
let socket = ctx.address().recipient();
self.handle_binary_message(bytes, socket);
}
Ok(Text(_)) => {
tracing::warn!("Receive unexpected text message");
}
Ok(Close(reason)) => {
ctx.close(reason);
ctx.stop();
}
Ok(ws::Message::Continuation(_)) => {}
Ok(ws::Message::Nop) => {}
Err(e) => {
tracing::error!("WebSocketStream protocol error {:?}", e);
ctx.stop();
}
}
}
}
impl Handler<WebSocketMessage> for WSClient {
type Result = ();
fn handle(&mut self, msg: WebSocketMessage, ctx: &mut Self::Context) {
ctx.binary(msg.0);
}
}
impl Actor for WSClient {
type Context = ws::WebsocketContext<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
self.hb(ctx);
let socket = ctx.address().recipient();
let connect = Connect {
socket,
user: self.user.clone(),
};
self.server
.send(connect)
.into_actor(self)
.then(|res, _client, _ctx| {
match res {
Ok(Ok(_)) => tracing::trace!("Send connect message to server success"),
Ok(Err(e)) => tracing::error!("Send connect message to server failed: {:?}", e),
Err(e) => tracing::error!("Send connect message to server failed: {:?}", e),
}
fut::ready(())
})
.wait(ctx);
}
fn stopping(&mut self, _: &mut Self::Context) -> Running {
self.server.do_send(Disconnect {
user: self.user.clone(),
});
Running::Stop
}
}

View file

@ -1,92 +0,0 @@
use crate::component::auth::LoggedUser;
use actix::{Message, Recipient};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::fmt::Formatter;
use std::sync::Arc;
pub type Socket = Recipient<WebSocketMessage>;
#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)]
pub struct WSSessionId(pub String);
impl<T: AsRef<str>> std::convert::From<T> for WSSessionId {
fn from(s: T) -> Self {
WSSessionId(s.as_ref().to_owned())
}
}
impl std::fmt::Display for WSSessionId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let desc = &self.0.to_string();
f.write_str(desc)
}
}
pub struct Session {
pub user: Arc<LoggedUser>,
pub socket: Socket,
}
impl std::convert::From<Connect> for Session {
fn from(c: Connect) -> Self {
Self {
user: c.user,
socket: c.socket,
}
}
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), WSError>")]
pub struct Connect {
pub socket: Socket,
pub user: Arc<LoggedUser>,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), WSError>")]
pub struct Disconnect {
pub user: Arc<LoggedUser>,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "()")]
pub struct WebSocketMessage(pub Bytes);
impl std::ops::Deref for WebSocketMessage {
type Target = Bytes;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct MessagePayload {
pub(crate) channel: u8,
pub(crate) detail: MessageDetail,
}
impl MessagePayload {
pub fn from_bytes<T: AsRef<[u8]>>(bytes: T) -> Self {
serde_json::from_slice(bytes.as_ref()).unwrap()
}
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum MessageDetail {
Document(MessageContent),
Database(MessageContent),
}
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageContent {
content: String,
}
#[derive(Debug)]
pub enum WSError {
Internal,
}

View file

@ -1,11 +0,0 @@
use std::time::Duration;
mod client;
mod entities;
mod server;
pub use client::*;
pub use server::WSServer;
pub(crate) const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(8);
pub(crate) const PING_TIMEOUT: Duration = Duration::from_secs(60);

View file

@ -1,44 +0,0 @@
use crate::component::ws::entities::{Connect, Disconnect, WSError, WebSocketMessage};
use actix::{Actor, Context, Handler};
#[derive(Default)]
pub struct WSServer {}
impl WSServer {
pub fn new() -> Self {
WSServer::default()
}
pub fn send(&self, _msg: WebSocketMessage) {}
}
impl Actor for WSServer {
type Context = Context<Self>;
fn started(&mut self, _ctx: &mut Self::Context) {}
}
impl Handler<Connect> for WSServer {
type Result = Result<(), WSError>;
fn handle(&mut self, _msg: Connect, _ctx: &mut Context<Self>) -> Self::Result {
Ok(())
}
}
impl Handler<Disconnect> for WSServer {
type Result = Result<(), WSError>;
fn handle(&mut self, _msg: Disconnect, _: &mut Context<Self>) -> Self::Result {
Ok(())
}
}
impl Handler<WebSocketMessage> for WSServer {
type Result = ();
fn handle(&mut self, _msg: WebSocketMessage, _ctx: &mut Context<Self>) -> Self::Result {}
}
impl actix::Supervised for WSServer {
fn restarting(&mut self, _ctx: &mut Context<WSServer>) {
tracing::warn!("restarting");
}
}

View file

@ -3,12 +3,13 @@ use secrecy::Secret;
use serde_aux::field_attributes::deserialize_number_from_string; use serde_aux::field_attributes::deserialize_number_from_string;
use sqlx::postgres::{PgConnectOptions, PgSslMode}; use sqlx::postgres::{PgConnectOptions, PgSslMode};
use std::convert::{TryFrom, TryInto}; use std::convert::{TryFrom, TryInto};
use std::path::PathBuf;
#[derive(serde::Deserialize, Clone, Debug)] #[derive(serde::Deserialize, Clone, Debug)]
pub struct Config { pub struct Config {
pub database: DatabaseSetting, pub database: DatabaseSetting,
pub application: ApplicationSettings, pub application: ApplicationSettings,
pub redis_uri: Secret<String>, pub redis_uri: Secret<String>,
} }
// We are using 127.0.0.1 as our host in address, we are instructing our // We are using 127.0.0.1 as our host in address, we are instructing our
@ -21,73 +22,78 @@ pub struct Config {
// //
#[derive(serde::Deserialize, Clone, Debug)] #[derive(serde::Deserialize, Clone, Debug)]
pub struct ApplicationSettings { pub struct ApplicationSettings {
#[serde(deserialize_with = "deserialize_number_from_string")] #[serde(deserialize_with = "deserialize_number_from_string")]
pub port: u16, pub port: u16,
pub host: String, pub host: String,
pub server_key: Secret<String>, pub data_dir: PathBuf,
pub tls_config: Option<TlsConfig>, pub server_key: Secret<String>,
pub tls_config: Option<TlsConfig>,
} }
impl ApplicationSettings { impl ApplicationSettings {
pub fn use_https(&self) -> bool { pub fn use_https(&self) -> bool {
match &self.tls_config { match &self.tls_config {
None => false, None => false,
Some(config) => match config { Some(config) => match config {
TlsConfig::NoTls => false, TlsConfig::NoTls => false,
TlsConfig::SelfSigned => true, TlsConfig::SelfSigned => true,
}, },
}
} }
}
pub fn rocksdb_db_dir(&self) -> PathBuf {
self.data_dir.join("rocksdb")
}
} }
#[derive(serde::Deserialize, Clone, Debug)] #[derive(serde::Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum TlsConfig { pub enum TlsConfig {
NoTls, NoTls,
SelfSigned, SelfSigned,
} }
#[derive(serde::Deserialize, Clone, Debug)] #[derive(serde::Deserialize, Clone, Debug)]
pub struct DatabaseSetting { pub struct DatabaseSetting {
pub username: String, pub username: String,
pub password: String, pub password: String,
#[serde(deserialize_with = "deserialize_number_from_string")] #[serde(deserialize_with = "deserialize_number_from_string")]
pub port: u16, pub port: u16,
pub host: String, pub host: String,
pub database_name: String, pub database_name: String,
pub require_ssl: bool, pub require_ssl: bool,
} }
impl DatabaseSetting { impl DatabaseSetting {
pub fn without_db(&self) -> PgConnectOptions { pub fn without_db(&self) -> PgConnectOptions {
let ssl_mode = if self.require_ssl { let ssl_mode = if self.require_ssl {
PgSslMode::Require PgSslMode::Require
} else { } else {
PgSslMode::Prefer PgSslMode::Prefer
}; };
PgConnectOptions::new() PgConnectOptions::new()
.host(&self.host) .host(&self.host)
.username(&self.username) .username(&self.username)
.password(&self.password) .password(&self.password)
.port(self.port) .port(self.port)
.ssl_mode(ssl_mode) .ssl_mode(ssl_mode)
} }
pub fn with_db(&self) -> PgConnectOptions { pub fn with_db(&self) -> PgConnectOptions {
self.without_db().database(&self.database_name) self.without_db().database(&self.database_name)
} }
} }
pub fn get_configuration() -> Result<Config, config::ConfigError> { pub fn get_configuration() -> Result<Config, config::ConfigError> {
let base_path = std::env::current_dir().expect("Failed to determine the current directory"); let base_path = std::env::current_dir().expect("Failed to determine the current directory");
let configuration_dir = base_path.join("configuration"); let configuration_dir = base_path.join("configuration");
let environment: Environment = std::env::var("APP_ENVIRONMENT") let environment: Environment = std::env::var("APP_ENVIRONMENT")
.unwrap_or_else(|_| "local".into()) .unwrap_or_else(|_| "local".into())
.try_into() .try_into()
.expect("Failed to parse APP_ENVIRONMENT."); .expect("Failed to parse APP_ENVIRONMENT.");
let builder = InnerConfig::builder() let builder = InnerConfig::builder()
.set_default("default", "1")? .set_default("default", "1")?
.add_source( .add_source(
config::File::from(configuration_dir.join("base")) config::File::from(configuration_dir.join("base"))
@ -104,36 +110,36 @@ pub fn get_configuration() -> Result<Config, config::ConfigError> {
// `Settings.application.port` // `Settings.application.port`
.add_source(config::Environment::with_prefix("app").separator("__")); .add_source(config::Environment::with_prefix("app").separator("__"));
let config = builder.build()?; let config = builder.build()?;
config.try_deserialize() config.try_deserialize()
} }
/// The possible runtime environment for our application. /// The possible runtime environment for our application.
pub enum Environment { pub enum Environment {
Local, Local,
Production, Production,
} }
impl Environment { impl Environment {
pub fn as_str(&self) -> &'static str { pub fn as_str(&self) -> &'static str {
match self { match self {
Environment::Local => "local", Environment::Local => "local",
Environment::Production => "production", Environment::Production => "production",
}
} }
}
} }
impl TryFrom<String> for Environment { impl TryFrom<String> for Environment {
type Error = String; type Error = String;
fn try_from(s: String) -> Result<Self, Self::Error> { fn try_from(s: String) -> Result<Self, Self::Error> {
match s.to_lowercase().as_str() { match s.to_lowercase().as_str() {
"local" => Ok(Self::Local), "local" => Ok(Self::Local),
"production" => Ok(Self::Production), "production" => Ok(Self::Production),
other => Err(format!( other => Err(format!(
"{} is not a supported environment. Use either `local` or `production`.", "{} is not a supported environment. Use either `local` or `production`.",
other other
)), )),
}
} }
}
} }

View file

@ -1,13 +1,13 @@
use std::env; use std::env;
pub fn domain() -> String { pub fn domain() -> String {
env::var("DOMAIN").unwrap_or_else(|_| "localhost".to_string()) env::var("DOMAIN").unwrap_or_else(|_| "localhost".to_string())
} }
pub fn jwt_secret() -> String { pub fn jwt_secret() -> String {
env::var("JWT_SECRET").unwrap_or_else(|_| "my secret".into()) env::var("JWT_SECRET").unwrap_or_else(|_| "my secret".into())
} }
pub fn secret() -> String { pub fn secret() -> String {
env::var("SECRET_KEY").unwrap_or_else(|_| "0123".repeat(8)) env::var("SECRET_KEY").unwrap_or_else(|_| "0123".repeat(8))
} }

View file

@ -4,44 +4,44 @@ use validator::validate_email;
pub struct UserEmail(pub String); pub struct UserEmail(pub String);
impl UserEmail { impl UserEmail {
pub fn parse(s: String) -> Result<UserEmail, String> { pub fn parse(s: String) -> Result<UserEmail, String> {
if s.trim().is_empty() { if s.trim().is_empty() {
return Err("Email can not be empty or whitespace".to_string()); return Err("Email can not be empty or whitespace".to_string());
}
if validate_email(&s) {
Ok(Self(s))
} else {
Err("Invalid email".to_string())
}
} }
if validate_email(&s) {
Ok(Self(s))
} else {
Err("Invalid email".to_string())
}
}
} }
impl AsRef<str> for UserEmail { impl AsRef<str> for UserEmail {
fn as_ref(&self) -> &str { fn as_ref(&self) -> &str {
&self.0 &self.0
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn empty_string_is_rejected() { fn empty_string_is_rejected() {
let email = "".to_string(); let email = "".to_string();
assert!(UserEmail::parse(email).is_err()); assert!(UserEmail::parse(email).is_err());
} }
#[test] #[test]
fn email_missing_at_symbol_is_rejected() { fn email_missing_at_symbol_is_rejected() {
let email = "helloworld.com".to_string(); let email = "helloworld.com".to_string();
assert!(UserEmail::parse(email).is_err()); assert!(UserEmail::parse(email).is_err());
} }
#[test] #[test]
fn email_missing_subject_is_rejected() { fn email_missing_subject_is_rejected() {
let email = "@domain.com".to_string(); let email = "@domain.com".to_string();
assert!(UserEmail::parse(email).is_err()); assert!(UserEmail::parse(email).is_err());
} }
} }

View file

@ -4,78 +4,78 @@ use unicode_segmentation::UnicodeSegmentation;
pub struct UserName(pub String); pub struct UserName(pub String);
impl UserName { impl UserName {
pub fn parse(s: String) -> Result<UserName, String> { pub fn parse(s: String) -> Result<UserName, String> {
let is_empty_or_whitespace = s.trim().is_empty(); let is_empty_or_whitespace = s.trim().is_empty();
if is_empty_or_whitespace { if is_empty_or_whitespace {
return Err("User name can not be empty or whitespace".to_string()); return Err("User name can not be empty or whitespace".to_string());
}
// A grapheme is defined by the Unicode standard as a "user-perceived"
// character: `å` is a single grapheme, but it is composed of two characters
// (`a` and `̊`).
//
// `graphemes` returns an iterator over the graphemes in the input `s`.
// `true` specifies that we want to use the extended grapheme definition set,
// the recommended one.
let is_too_long = s.graphemes(true).count() > 256;
if is_too_long {
return Err("User name is too long".to_string());
}
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
if contains_forbidden_characters {
return Err("User name contains invalid characters".to_string());
}
Ok(Self(s))
} }
// A grapheme is defined by the Unicode standard as a "user-perceived"
// character: `å` is a single grapheme, but it is composed of two characters
// (`a` and `̊`).
//
// `graphemes` returns an iterator over the graphemes in the input `s`.
// `true` specifies that we want to use the extended grapheme definition set,
// the recommended one.
let is_too_long = s.graphemes(true).count() > 256;
if is_too_long {
return Err("User name is too long".to_string());
}
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
if contains_forbidden_characters {
return Err("User name contains invalid characters".to_string());
}
Ok(Self(s))
}
} }
impl AsRef<str> for UserName { impl AsRef<str> for UserName {
fn as_ref(&self) -> &str { fn as_ref(&self) -> &str {
&self.0 &self.0
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::UserName; use super::UserName;
#[test] #[test]
fn a_256_grapheme_long_name_is_valid() { fn a_256_grapheme_long_name_is_valid() {
let name = "".repeat(256); let name = "".repeat(256);
assert!(UserName::parse(name).is_ok()); assert!(UserName::parse(name).is_ok());
} }
#[test] #[test]
fn a_name_longer_than_256_graphemes_is_rejected() { fn a_name_longer_than_256_graphemes_is_rejected() {
let name = "a".repeat(257); let name = "a".repeat(257);
assert!(UserName::parse(name).is_err()); assert!(UserName::parse(name).is_err());
} }
#[test] #[test]
fn whitespace_only_names_are_rejected() { fn whitespace_only_names_are_rejected() {
let name = " ".to_string(); let name = " ".to_string();
assert!(UserName::parse(name).is_err()); assert!(UserName::parse(name).is_err());
} }
#[test] #[test]
fn empty_string_is_rejected() { fn empty_string_is_rejected() {
let name = "".to_string(); let name = "".to_string();
assert!(UserName::parse(name).is_err()); assert!(UserName::parse(name).is_err());
} }
#[test] #[test]
fn names_containing_an_invalid_character_are_rejected() { fn names_containing_an_invalid_character_are_rejected() {
for name in &['/', '(', ')', '"', '<', '>', '\\', '{', '}'] { for name in &['/', '(', ')', '"', '<', '>', '\\', '{', '}'] {
let name = name.to_string(); let name = name.to_string();
assert!(UserName::parse(name).is_err()); assert!(UserName::parse(name).is_err());
}
} }
}
#[test] #[test]
fn a_valid_name_is_parsed_successfully() { fn a_valid_name_is_parsed_successfully() {
let name = "nathan".to_string(); let name = "nathan".to_string();
assert!(UserName::parse(name).is_ok()); assert!(UserName::parse(name).is_ok());
} }
} }

View file

@ -6,33 +6,33 @@ use unicode_segmentation::UnicodeSegmentation;
pub struct UserPassword(pub String); pub struct UserPassword(pub String);
impl UserPassword { impl UserPassword {
pub fn parse(s: String) -> Result<UserPassword, String> { pub fn parse(s: String) -> Result<UserPassword, String> {
if s.trim().is_empty() { if s.trim().is_empty() {
return Err("User password can not be empty or whitespace".to_owned()); return Err("User password can not be empty or whitespace".to_owned());
}
if s.graphemes(true).count() > 100 {
return Err("Password is too long".to_owned());
}
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
if contains_forbidden_characters {
return Err("Password contains invalid characters".to_string());
}
if !validate_password(&s) {
return Err("Password format invalid".to_string());
}
Ok(Self(s))
} }
if s.graphemes(true).count() > 100 {
return Err("Password is too long".to_owned());
}
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
if contains_forbidden_characters {
return Err("Password contains invalid characters".to_string());
}
if !validate_password(&s) {
return Err("Password format invalid".to_string());
}
Ok(Self(s))
}
} }
impl AsRef<str> for UserPassword { impl AsRef<str> for UserPassword {
fn as_ref(&self) -> &str { fn as_ref(&self) -> &str {
&self.0 &self.0
} }
} }
lazy_static! { lazy_static! {
@ -53,11 +53,11 @@ lazy_static! {
} }
pub fn validate_password(password: &str) -> bool { pub fn validate_password(password: &str) -> bool {
match PASSWORD.is_match(password) { match PASSWORD.is_match(password) {
Ok(is_match) => is_match, Ok(is_match) => is_match,
Err(e) => { Err(e) => {
tracing::error!("validate_password fail: {:?}", e); tracing::error!("validate_password fail: {:?}", e);
false false
} },
} }
} }

View file

@ -4,13 +4,17 @@ use appflowy_server::telemetry::{get_subscriber, init_subscriber};
#[actix_web::main] #[actix_web::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let subscriber = get_subscriber("appflowy_server".into(), "info".into(), std::io::stdout); let subscriber = get_subscriber(
init_subscriber(subscriber); "appflowy_server".to_string(),
"info".to_string(),
std::io::stdout,
);
init_subscriber(subscriber);
let configuration = get_configuration().expect("Failed to read configuration."); let configuration = get_configuration().expect("Failed to read configuration.");
let state = init_state(&configuration).await; let state = init_state(&configuration).await;
let application = Application::build(configuration, state).await?; let application = Application::build(configuration, state).await?;
application.run_until_stopped().await?; application.run_until_stopped().await?;
Ok(()) Ok(())
} }

View file

@ -6,11 +6,11 @@ use actix_web::http;
// http://www.ruanyifeng.com/blog/2016/04/cors.html // http://www.ruanyifeng.com/blog/2016/04/cors.html
// Cors short for Cross-Origin Resource Sharing. // Cors short for Cross-Origin Resource Sharing.
pub fn default_cors() -> Cors { pub fn default_cors() -> Cors {
Cors::default() // allowed_origin return access-control-allow-origin: * by default Cors::default() // allowed_origin return access-control-allow-origin: * by default
// .allowed_origin("http://127.0.0.1:8080") .allow_any_origin()
.send_wildcard() .send_wildcard()
.allowed_methods(vec!["GET", "POST", "PUT", "DELETE"]) .allowed_methods(vec!["GET", "POST", "PUT", "DELETE"])
.allowed_headers(vec![http::header::ACCEPT]) .allowed_headers(vec![http::header::ACCEPT])
.allowed_header(http::header::CONTENT_TYPE) .allowed_header(http::header::CONTENT_TYPE)
.max_age(3600) .max_age(3600)
} }

View file

@ -5,26 +5,26 @@ pub const CA_CRT: &str = include_str!("../cert/cert.pem");
pub const CA_KEY: &str = include_str!("../cert/key.pem"); pub const CA_KEY: &str = include_str!("../cert/key.pem");
pub fn create_self_signed_certificate() -> Result<(Secret<String>, Secret<String>), RcgenError> { pub fn create_self_signed_certificate() -> Result<(Secret<String>, Secret<String>), RcgenError> {
let key = KeyPair::from_pem(CA_KEY)?; let key = KeyPair::from_pem(CA_KEY)?;
let params = CertificateParams::from_ca_cert_pem(CA_CRT, key)?; let params = CertificateParams::from_ca_cert_pem(CA_CRT, key)?;
let ca_cert = Certificate::from_params(params)?; let ca_cert = Certificate::from_params(params)?;
let mut params = CertificateParams::default(); let mut params = CertificateParams::default();
params params
.subject_alt_names .subject_alt_names
.push(SanType::IpAddress("127.0.0.1".parse().unwrap())); .push(SanType::IpAddress("127.0.0.1".parse().unwrap()));
params params
.subject_alt_names .subject_alt_names
.push(SanType::IpAddress("0.0.0.0".parse().unwrap())); .push(SanType::IpAddress("0.0.0.0".parse().unwrap()));
params params
.subject_alt_names .subject_alt_names
.push(SanType::DnsName("localhost".to_string())); .push(SanType::DnsName("localhost".to_string()));
// Generate a certificate that's valid for: // Generate a certificate that's valid for:
// 1. localhost // 1. localhost
// 2. 127.0.0.1 // 2. 127.0.0.1
let gen_cert = Certificate::from_params(params)?; let gen_cert = Certificate::from_params(params)?;
let server_crt = Secret::new(gen_cert.serialize_pem_with_signer(&ca_cert)?); let server_crt = Secret::new(gen_cert.serialize_pem_with_signer(&ca_cert)?);
let server_key = Secret::new(gen_cert.serialize_private_key_pem()); let server_key = Secret::new(gen_cert.serialize_private_key_pem());
Ok((server_crt, server_key)) Ok((server_crt, server_key))
} }

View file

@ -1,6 +1,8 @@
use crate::component::auth::LoggedUser; use crate::component::auth::LoggedUser;
use crate::config::config::Config; use crate::config::config::Config;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collab_persistence::kv::rocks_kv::RocksCollabDB;
use snowflake::Snowflake; use snowflake::Snowflake;
use sqlx::PgPool; use sqlx::PgPool;
use std::collections::BTreeMap; use std::collections::BTreeMap;
@ -9,70 +11,72 @@ use tokio::sync::RwLock;
#[derive(Clone)] #[derive(Clone)]
pub struct State { pub struct State {
pub pg_pool: PgPool, pub pg_pool: PgPool,
pub config: Arc<Config>, pub rocksdb: Arc<RocksCollabDB>,
pub user: Arc<RwLock<UserCache>>, pub config: Arc<Config>,
pub id_gen: Arc<RwLock<Snowflake>>, pub user: Arc<RwLock<UserCache>>,
pub id_gen: Arc<RwLock<Snowflake>>,
} }
impl State { impl State {
pub async fn load_users(_pool: &PgPool) { pub async fn load_users(_pool: &PgPool) {
todo!() todo!()
} }
pub async fn next_user_id(&self) -> i64 { pub async fn next_user_id(&self) -> i64 {
self.id_gen.write().await.next_id() self.id_gen.write().await.next_id()
} }
} }
#[derive(Clone, Debug, Copy)] #[derive(Clone, Debug, Copy)]
enum AuthStatus { enum AuthStatus {
Authorized(DateTime<Utc>), Authorized(DateTime<Utc>),
NotAuthorized, NotAuthorized,
} }
pub const EXPIRED_DURATION_DAYS: i64 = 30; pub const EXPIRED_DURATION_DAYS: i64 = 30;
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct UserCache { pub struct UserCache {
// Keep track the user authentication state // Keep track the user authentication state
user: BTreeMap<i64, AuthStatus>, user: BTreeMap<i64, AuthStatus>,
} }
impl UserCache { impl UserCache {
pub fn new() -> Self { pub fn new() -> Self {
UserCache::default() UserCache::default()
} }
pub fn is_authorized(&self, user: &LoggedUser) -> bool { pub fn is_authorized(&self, user: &LoggedUser) -> bool {
match self.user.get(user.expose_secret()) { match self.user.get(user.expose_secret()) {
None => { None => {
tracing::debug!("user not login yet or server was reboot"); tracing::debug!("user not login yet or server was reboot");
false false
} },
Some(status) => match *status { Some(status) => match *status {
AuthStatus::Authorized(last_time) => { AuthStatus::Authorized(last_time) => {
let current_time = Utc::now(); let current_time = Utc::now();
let days = (current_time - last_time).num_days(); let days = (current_time - last_time).num_days();
days < EXPIRED_DURATION_DAYS days < EXPIRED_DURATION_DAYS
} },
AuthStatus::NotAuthorized => { AuthStatus::NotAuthorized => {
tracing::debug!("user logout already"); tracing::debug!("user logout already");
false false
} },
}, },
}
} }
}
pub fn authorized(&mut self, user: LoggedUser) { pub fn authorized(&mut self, user: LoggedUser) {
self.user.insert( self.user.insert(
user.expose_secret().to_owned(), user.expose_secret().to_owned(),
AuthStatus::Authorized(Utc::now()), AuthStatus::Authorized(Utc::now()),
); );
} }
pub fn unauthorized(&mut self, user: LoggedUser) { pub fn unauthorized(&mut self, user: LoggedUser) {
self.user self
.insert(user.expose_secret().to_owned(), AuthStatus::NotAuthorized); .user
} .insert(user.expose_secret().to_owned(), AuthStatus::NotAuthorized);
}
} }

View file

@ -4,39 +4,41 @@ use tracing::Subscriber;
use tracing_bunyan_formatter::{BunyanFormattingLayer, JsonStorageLayer}; use tracing_bunyan_formatter::{BunyanFormattingLayer, JsonStorageLayer};
use tracing_log::LogTracer; use tracing_log::LogTracer;
use tracing_subscriber::fmt::MakeWriter; use tracing_subscriber::fmt::MakeWriter;
use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry}; use tracing_subscriber::{layer::SubscriberExt, EnvFilter};
/// Compose multiple layers into a `tracing`'s subscriber. /// Compose multiple layers into a `tracing`'s subscriber.
pub fn get_subscriber<Sink>( pub fn get_subscriber<Sink>(
name: String, name: String,
env_filter: String, env_filter: String,
sink: Sink, sink: Sink,
) -> impl Subscriber + Sync + Send ) -> impl Subscriber + Sync + Send
where where
Sink: for<'a> MakeWriter<'a> + Send + Sync + 'static, Sink: for<'a> MakeWriter<'a> + Send + Sync + 'static,
{ {
let env_filter = let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(env_filter));
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(env_filter)); // let env_filter = EnvFilter::new(env_filter);
let formatting_layer = BunyanFormattingLayer::new(name, sink); let formatting_layer = BunyanFormattingLayer::new(name, sink);
Registry::default() tracing_subscriber::fmt()
.with(env_filter) .with_ansi(true)
.with(JsonStorageLayer) .finish()
.with(formatting_layer) .with(env_filter)
.with(JsonStorageLayer)
.with(formatting_layer)
} }
/// Register a subscriber as global default to process span data. /// Register a subscriber as global default to process span data.
/// ///
/// It should only be called once! /// It should only be called once!
pub fn init_subscriber(subscriber: impl Subscriber + Sync + Send) { pub fn init_subscriber(subscriber: impl Subscriber + Sync + Send) {
LogTracer::init().expect("Failed to set logger"); LogTracer::init().expect("Failed to set logger");
set_global_default(subscriber).expect("Failed to set subscriber"); set_global_default(subscriber).expect("Failed to set subscriber");
} }
pub fn spawn_blocking_with_tracing<F, R>(f: F) -> JoinHandle<R> pub fn spawn_blocking_with_tracing<F, R>(f: F) -> JoinHandle<R>
where where
F: FnOnce() -> R + Send + 'static, F: FnOnce() -> R + Send + 'static,
R: Send + 'static, R: Send + 'static,
{ {
let current_span = tracing::Span::current(); let current_span = tracing::Span::current();
actix_web::rt::task::spawn_blocking(move || current_span.in_scope(f)) actix_web::rt::task::spawn_blocking(move || current_span.in_scope(f))
} }

View file

@ -1,44 +1,44 @@
use crate::test_server::{spawn_server, TestUser}; use crate::util::{spawn_server, TestUser};
use actix_web::http::StatusCode; use actix_web::http::StatusCode;
use appflowy_server::component::auth::LoginResponse; use appflowy_server::component::auth::LoginResponse;
#[tokio::test] #[actix_rt::test]
async fn login_success() { async fn login_success() {
let server = spawn_server().await; let server = spawn_server().await;
let test_user = TestUser::generate(); let test_user = TestUser::generate();
test_user.register(&server).await; test_user.register(&server).await;
let http_resp = server.login(&test_user.email, &test_user.password).await; let http_resp = server.login(&test_user.email, &test_user.password).await;
assert_eq!(http_resp.status(), StatusCode::OK); assert_eq!(http_resp.status(), StatusCode::OK);
let bytes = http_resp.bytes().await.unwrap(); let bytes = http_resp.bytes().await.unwrap();
let response: LoginResponse = serde_json::from_slice(&bytes).unwrap(); let response: LoginResponse = serde_json::from_slice(&bytes).unwrap();
assert!(!response.token.is_empty()) assert!(!response.token.is_empty())
} }
#[tokio::test] #[actix_rt::test]
async fn login_with_empty_email() { async fn login_with_empty_email() {
let server = spawn_server().await; let server = spawn_server().await;
let test_user = TestUser::generate(); let test_user = TestUser::generate();
test_user.register(&server).await; test_user.register(&server).await;
let http_resp = server.login("", &test_user.password).await; let http_resp = server.login("", &test_user.password).await;
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
} }
#[tokio::test] #[actix_rt::test]
async fn login_with_empty_password() { async fn login_with_empty_password() {
let server = spawn_server().await; let server = spawn_server().await;
let test_user = TestUser::generate(); let test_user = TestUser::generate();
test_user.register(&server).await; test_user.register(&server).await;
let http_resp = server.login(&test_user.email, "").await; let http_resp = server.login(&test_user.email, "").await;
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
} }
#[tokio::test] #[actix_rt::test]
async fn login_with_unknown_user() { async fn login_with_unknown_user() {
let server = spawn_server().await; let server = spawn_server().await;
let http_resp = server.login("unknown@appflowy.io", "Abc@123!").await; let http_resp = server.login("unknown@appflowy.io", "Abc@123!").await;
assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED); assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED);
} }

View file

@ -1,4 +1,4 @@
mod login; mod login;
mod password; mod password;
mod register; mod register;
mod test_server; mod ws;

View file

@ -1,53 +1,53 @@
use crate::test_server::{spawn_server, TestUser}; use crate::util::{spawn_server, TestUser};
use actix_web::http::StatusCode; use actix_web::http::StatusCode;
#[tokio::test] #[actix_rt::test]
async fn change_password_with_unmatched_password() { async fn change_password_with_unmatched_password() {
let server = spawn_server().await; let server = spawn_server().await;
let test_user = TestUser::generate(); let test_user = TestUser::generate();
let token = test_user.register(&server).await; let token = test_user.register(&server).await;
let new_password = "HelloWorld@1a"; let new_password = "HelloWorld@1a";
let new_password_confirm = "HeloWorld@1a"; let new_password_confirm = "HeloWorld@1a";
let http_resp = server let http_resp = server
.change_password( .change_password(
token, token,
&test_user.password, &test_user.password,
new_password, new_password,
new_password_confirm, new_password_confirm,
) )
.await; .await;
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
} }
#[tokio::test] #[actix_rt::test]
async fn login_fail_after_change_password() { async fn login_fail_after_change_password() {
let server = spawn_server().await; let server = spawn_server().await;
let test_user = TestUser::generate(); let test_user = TestUser::generate();
let token = test_user.register(&server).await; let token = test_user.register(&server).await;
let new_password = "HelloWorld@1a"; let new_password = "HelloWorld@1a";
let http_resp = server let http_resp = server
.change_password(token, &test_user.password, new_password, new_password) .change_password(token, &test_user.password, new_password, new_password)
.await; .await;
assert_eq!(http_resp.status(), StatusCode::OK); assert_eq!(http_resp.status(), StatusCode::OK);
let http_resp = server.login(&test_user.email, &test_user.password).await; let http_resp = server.login(&test_user.email, &test_user.password).await;
assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED); assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED);
} }
#[tokio::test] #[actix_rt::test]
async fn login_success_with_new_password() { async fn login_success_with_new_password() {
let server = spawn_server().await; let server = spawn_server().await;
let test_user = TestUser::generate(); let test_user = TestUser::generate();
let token = test_user.register(&server).await; let token = test_user.register(&server).await;
let new_password = "HelloWorld@1a"; let new_password = "HelloWorld@1a";
let http_resp = server let http_resp = server
.change_password(token, &test_user.password, new_password, new_password) .change_password(token, &test_user.password, new_password, new_password)
.await; .await;
assert_eq!(http_resp.status(), StatusCode::OK); assert_eq!(http_resp.status(), StatusCode::OK);
let http_resp = server.login(&test_user.email, new_password).await; let http_resp = server.login(&test_user.email, new_password).await;
assert_eq!(http_resp.status(), StatusCode::OK); assert_eq!(http_resp.status(), StatusCode::OK);
} }

View file

@ -1,53 +1,53 @@
use crate::test_server::{error_msg_from_resp, spawn_server}; use crate::util::{error_msg_from_resp, spawn_server};
use appflowy_server::component::auth::{InputParamsError, RegisterResponse}; use appflowy_server::component::auth::{InputParamsError, RegisterResponse};
use reqwest::StatusCode; use reqwest::StatusCode;
#[tokio::test] #[actix_rt::test]
// curl -X POST --url http://0.0.0.0:8000/api/user/register --header 'content-type: application/json' --data '{"name":"fake name", "email":"fake@appflowy.io", "password":"Fake@123"}' // curl -X POST --url http://0.0.0.0:8000/api/user/register --header 'content-type: application/json' --data '{"name":"fake name", "email":"fake@appflowy.io", "password":"Fake@123"}'
async fn register_success() { async fn register_success() {
let server = spawn_server().await; let server = spawn_server().await;
let http_resp = server let http_resp = server
.register("user 1", "fake@appflowy.io", "FakePassword!123") .register("user 1", "fake@appflowy.io", "FakePassword!123")
.await; .await;
let bytes = http_resp.bytes().await.unwrap(); let bytes = http_resp.bytes().await.unwrap();
let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap(); let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap();
println!("{:?}", response); println!("{:?}", response);
} }
#[tokio::test] #[actix_rt::test]
async fn register_with_invalid_password() { async fn register_with_invalid_password() {
let server = spawn_server().await; let server = spawn_server().await;
let http_resp = server.register("user 1", "fake@appflowy.io", "123").await; let http_resp = server.register("user 1", "fake@appflowy.io", "123").await;
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_msg_from_resp(http_resp).await, error_msg_from_resp(http_resp).await,
InputParamsError::InvalidPassword.to_string() InputParamsError::InvalidPassword.to_string()
); );
} }
#[tokio::test] #[actix_rt::test]
async fn register_with_invalid_name() { async fn register_with_invalid_name() {
let server = spawn_server().await; let server = spawn_server().await;
let name = "".to_string(); let name = "".to_string();
let http_resp = server let http_resp = server
.register(&name, "fake@appflowy.io", "FakePassword!123") .register(&name, "fake@appflowy.io", "FakePassword!123")
.await; .await;
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_msg_from_resp(http_resp).await, error_msg_from_resp(http_resp).await,
InputParamsError::InvalidName(name).to_string() InputParamsError::InvalidName(name).to_string()
); );
} }
#[tokio::test] #[actix_rt::test]
async fn register_with_invalid_email() { async fn register_with_invalid_email() {
let server = spawn_server().await; let server = spawn_server().await;
let email = "appflowy.io".to_string(); let email = "appflowy.io".to_string();
let http_resp = server.register("me", &email, "FakePassword!123").await; let http_resp = server.register("me", &email, "FakePassword!123").await;
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_msg_from_resp(http_resp).await, error_msg_from_resp(http_resp).await,
InputParamsError::InvalidEmail(email).to_string() InputParamsError::InvalidEmail(email).to_string()
); );
} }

View file

@ -1,189 +0,0 @@
use appflowy_server::application::{init_state, Application};
use appflowy_server::config::config::{get_configuration, DatabaseSetting, TlsConfig};
use appflowy_server::state::State;
use appflowy_server::telemetry::{get_subscriber, init_subscriber};
use once_cell::sync::Lazy;
use reqwest::Certificate;
use appflowy_server::component::auth::{RegisterResponse, HEADER_TOKEN};
use sqlx::types::Uuid;
use sqlx::{Connection, Executor, PgConnection, PgPool};
// Ensure that the `tracing` stack is only initialised once using `once_cell`
static TRACING: Lazy<()> = Lazy::new(|| {
let level = "info".to_string();
let mut filters = vec![];
filters.push(format!("appflowy_server={}", level));
filters.push(format!("hyper={}", level));
let subscriber_name = "test".to_string();
let subscriber = get_subscriber(subscriber_name, filters.join(","), std::io::stdout);
init_subscriber(subscriber);
});
#[derive(Clone)]
pub struct TestServer {
pub state: State,
pub api_client: reqwest::Client,
pub address: String,
pub port: u16,
}
impl TestServer {
pub async fn register(&self, name: &str, email: &str, password: &str) -> reqwest::Response {
let payload = serde_json::json!({
"name": name,
"password": password,
"email": email
});
let url = format!("{}/api/user/register", self.address);
self.api_client
.post(&url)
.json(&payload)
.send()
.await
.expect("Register failed")
}
pub async fn login(&self, email: &str, password: &str) -> reqwest::Response {
let payload = serde_json::json!({
"password": password,
"email": email
});
let url = format!("{}/api/user/login", self.address);
self.api_client
.post(&url)
.json(&payload)
.send()
.await
.expect("Login failed")
}
pub async fn change_password(
&self,
token: String,
current_password: &str,
new_password: &str,
new_password_confirm: &str,
) -> reqwest::Response {
let payload = serde_json::json!({
"current_password": current_password,
"new_password": new_password,
"new_password_confirm": new_password_confirm
});
let url = format!("{}/api/user/password", self.address);
self.api_client
.post(&url)
.header(HEADER_TOKEN, token)
.json(&payload)
.send()
.await
.expect("Change password failed")
}
}
pub async fn spawn_server() -> TestServer {
Lazy::force(&TRACING);
let database_name = Uuid::new_v4().to_string();
let config = {
let mut config = get_configuration().expect("Failed to read configuration.");
config.database.database_name = database_name.clone();
// Use a random OS port
config.application.port = 0;
config
};
let _ = configure_database(&config.database).await;
let state = init_state(&config).await;
let application = Application::build(config.clone(), state.clone())
.await
.expect("Failed to build application");
let port = application.port();
let _ = tokio::spawn(async {
let _ = application.run_until_stopped().await;
});
let mut builder = reqwest::Client::builder();
let mut address = format!("http://localhost:{}", port);
if config.application.use_https() {
address = format!("https://localhost:{}", port);
builder = builder.add_root_certificate(
Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap(),
);
}
let api_client = builder
.add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap())
.redirect(reqwest::redirect::Policy::none())
.danger_accept_invalid_certs(true)
.cookie_store(true)
.no_proxy()
.build()
.unwrap();
TestServer {
state,
api_client,
address,
port,
}
}
async fn configure_database(config: &DatabaseSetting) -> PgPool {
// Create database
let mut connection = PgConnection::connect_with(&config.without_db())
.await
.expect("Failed to connect to Postgres");
connection
.execute(&*format!(r#"CREATE DATABASE "{}";"#, config.database_name))
.await
.expect("Failed to create database.");
// Migrate database
let connection_pool = PgPool::connect_with(config.with_db())
.await
.expect("Failed to connect to Postgres.");
sqlx::migrate!("./migrations")
.run(&connection_pool)
.await
.expect("Failed to migrate the database");
connection_pool
}
#[derive(serde::Serialize)]
pub struct TestUser {
name: String,
pub email: String,
pub password: String,
}
impl TestUser {
pub fn generate() -> Self {
Self {
name: "Me".to_string(),
email: "me@appflowy.io".to_string(),
password: "Hello@AppFlowy123".to_string(),
}
}
pub async fn register(&self, test_server: &TestServer) -> String {
let url = format!("{}/api/user/register", test_server.address);
let resp = test_server
.api_client
.post(&url)
.json(self)
.send()
.await
.expect("Fail to register user");
let bytes = resp.bytes().await.unwrap();
let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap();
response.token
}
}
pub async fn error_msg_from_resp(resp: reqwest::Response) -> String {
let bytes = resp.bytes().await.unwrap();
String::from_utf8(bytes.to_vec()).unwrap()
}

13
tests/api/ws.rs Normal file
View file

@ -0,0 +1,13 @@
use crate::util::{spawn_server, TestUser};
use collab_client_ws::WSClient;
#[actix_rt::test]
async fn ws_conn_test() {
let server = spawn_server().await;
let test_user = TestUser::generate();
let token = test_user.register(&server).await;
let address = format!("{}/{}", server.ws_addr, token);
let client = WSClient::new(address, 100);
let _ = client.connect().await;
}

3
tests/main.rs Normal file
View file

@ -0,0 +1,3 @@
mod api;
mod util;
mod ws;

2
tests/util/mod.rs Normal file
View file

@ -0,0 +1,2 @@
mod test_server;
pub use test_server::*;

232
tests/util/test_server.rs Normal file
View file

@ -0,0 +1,232 @@
use appflowy_server::application::{init_state, Application};
use appflowy_server::config::config::{get_configuration, DatabaseSetting};
use appflowy_server::state::State;
use appflowy_server::telemetry::{get_subscriber, init_subscriber};
use once_cell::sync::Lazy;
use reqwest::Certificate;
use std::path::PathBuf;
use appflowy_server::component::auth::{RegisterResponse, HEADER_TOKEN};
use sqlx::types::Uuid;
use sqlx::{Connection, Executor, PgConnection, PgPool};
// Ensure that the `tracing` stack is only initialised once using `once_cell`
static TRACING: Lazy<()> = Lazy::new(|| {
let level = "trace".to_string();
let mut filters = vec![];
filters.push(format!("appflowy_server={}", level));
filters.push(format!("collab_client_ws={}", level));
filters.push(format!("hyper={}", level));
filters.push(format!("actix_web={}", level));
let subscriber_name = "test".to_string();
let subscriber = get_subscriber(subscriber_name, filters.join(","), std::io::stdout);
init_subscriber(subscriber);
});
#[derive(Clone)]
pub struct TestServer {
pub state: State,
pub api_client: reqwest::Client,
pub address: String,
pub port: u16,
pub ws_addr: String,
#[allow(dead_code)]
pub cleaner: Cleaner,
}
impl TestServer {
pub async fn register(&self, name: &str, email: &str, password: &str) -> reqwest::Response {
let payload = serde_json::json!({
"name": name,
"password": password,
"email": email
});
let url = format!("{}/api/user/register", self.address);
self
.api_client
.post(&url)
.json(&payload)
.send()
.await
.expect("Register failed")
}
pub async fn login(&self, email: &str, password: &str) -> reqwest::Response {
let payload = serde_json::json!({
"password": password,
"email": email
});
let url = format!("{}/api/user/login", self.address);
self
.api_client
.post(&url)
.json(&payload)
.send()
.await
.expect("Login failed")
}
pub async fn change_password(
&self,
token: String,
current_password: &str,
new_password: &str,
new_password_confirm: &str,
) -> reqwest::Response {
let payload = serde_json::json!({
"current_password": current_password,
"new_password": new_password,
"new_password_confirm": new_password_confirm
});
let url = format!("{}/api/user/password", self.address);
self
.api_client
.post(&url)
.header(HEADER_TOKEN, token)
.json(&payload)
.send()
.await
.expect("Change password failed")
}
}
pub async fn spawn_server() -> TestServer {
Lazy::force(&TRACING);
let database_name = Uuid::new_v4().to_string();
let config = {
let mut config = get_configuration().expect("Failed to read configuration.");
config.database.database_name = database_name.clone();
// Use a random OS port
config.application.port = 0;
config.application.data_dir = PathBuf::from(format!("./data/{}", database_name));
config
};
let _ = configure_database(&config.database).await;
let state = init_state(&config).await;
let application = Application::build(config.clone(), state.clone())
.await
.expect("Failed to build application");
let port = application.port();
let _ = tokio::spawn(async {
let _ = application.run_until_stopped().await;
});
let mut builder = reqwest::Client::builder();
let mut address = format!("http://localhost:{}", port);
let mut ws_addr = format!("ws://localhost:{}/ws", port);
if config.application.use_https() {
address = format!("https://localhost:{}", port);
ws_addr = format!("wss://localhost:{}/ws", port);
builder = builder
.add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap());
}
let api_client = builder
.add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap())
.redirect(reqwest::redirect::Policy::none())
.danger_accept_invalid_certs(true)
.cookie_store(true)
.no_proxy()
.build()
.unwrap();
let cleaner = Cleaner::new(config.application.data_dir);
TestServer {
state,
api_client,
address,
ws_addr,
port,
cleaner,
}
}
async fn configure_database(config: &DatabaseSetting) -> PgPool {
// Create database
let mut connection = PgConnection::connect_with(&config.without_db())
.await
.expect("Failed to connect to Postgres");
connection
.execute(&*format!(r#"CREATE DATABASE "{}";"#, config.database_name))
.await
.expect("Failed to create database.");
// Migrate database
let connection_pool = PgPool::connect_with(config.with_db())
.await
.expect("Failed to connect to Postgres.");
sqlx::migrate!("./migrations")
.run(&connection_pool)
.await
.expect("Failed to migrate the database");
connection_pool
}
#[derive(serde::Serialize)]
pub struct TestUser {
name: String,
pub email: String,
pub password: String,
}
impl TestUser {
pub fn generate() -> Self {
Self {
name: "Me".to_string(),
email: "me@appflowy.io".to_string(),
password: "Hello@AppFlowy123".to_string(),
}
}
pub async fn register(&self, test_server: &TestServer) -> String {
let url = format!("{}/api/user/register", test_server.address);
let resp = test_server
.api_client
.post(&url)
.json(self)
.send()
.await
.expect("Fail to register user");
let bytes = resp.bytes().await.unwrap();
let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap();
response.token
}
}
pub async fn error_msg_from_resp(resp: reqwest::Response) -> String {
let bytes = resp.bytes().await.unwrap();
String::from_utf8(bytes.to_vec()).unwrap()
}
#[derive(Clone)]
pub struct Cleaner {
path: PathBuf,
should_clean: bool,
}
impl Cleaner {
fn new(path: PathBuf) -> Self {
Self {
path,
should_clean: true,
}
}
fn cleanup(dir: &PathBuf) {
let _ = std::fs::remove_dir_all(dir);
}
}
impl Drop for Cleaner {
fn drop(&mut self) {
if self.should_clean {
Self::cleanup(&self.path)
}
}
}

1
tests/ws/mod.rs Normal file
View file

@ -0,0 +1 @@