mirror of
https://github.com/AppFlowy-IO/AppFlowy-Cloud.git
synced 2025-04-19 03:24:42 -04:00
feat: ws connect (#3)
* chore: ws * chore: build client stream * feat: test ws connect * ci: fix ci
This commit is contained in:
parent
08847fad1d
commit
18e950a829
55 changed files with 2144 additions and 1868 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -8,4 +8,6 @@
|
||||||
**/temp/**
|
**/temp/**
|
||||||
package-lock.json
|
package-lock.json
|
||||||
yarn.lock
|
yarn.lock
|
||||||
node_modules
|
node_modules
|
||||||
|
**/crates/AppFlowy-Collab/
|
||||||
|
data/
|
449
Cargo.lock
generated
449
Cargo.lock
generated
|
@ -89,7 +89,7 @@ dependencies = [
|
||||||
"mime",
|
"mime",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"rand",
|
"rand 0.8.5",
|
||||||
"sha1",
|
"sha1",
|
||||||
"smallvec",
|
"smallvec",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
@ -189,7 +189,7 @@ dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"derive_more",
|
"derive_more",
|
||||||
"rand",
|
"rand 0.8.5",
|
||||||
"redis",
|
"redis",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
@ -370,7 +370,7 @@ version = "0.7.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
|
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"getrandom",
|
"getrandom 0.2.9",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"version_check",
|
"version_check",
|
||||||
]
|
]
|
||||||
|
@ -382,7 +382,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
|
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"getrandom",
|
"getrandom 0.2.9",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"version_check",
|
"version_check",
|
||||||
]
|
]
|
||||||
|
@ -446,6 +446,9 @@ dependencies = [
|
||||||
"bincode",
|
"bincode",
|
||||||
"bytes",
|
"bytes",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
"collab-client-ws",
|
||||||
|
"collab-persistence",
|
||||||
|
"collab-sync",
|
||||||
"config",
|
"config",
|
||||||
"dashmap",
|
"dashmap",
|
||||||
"derive_more",
|
"derive_more",
|
||||||
|
@ -454,7 +457,7 @@ dependencies = [
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"openssl",
|
"openssl",
|
||||||
"rand",
|
"rand 0.8.5",
|
||||||
"rcgen",
|
"rcgen",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"secrecy",
|
"secrecy",
|
||||||
|
@ -474,6 +477,7 @@ dependencies = [
|
||||||
"unicode-segmentation",
|
"unicode-segmentation",
|
||||||
"uuid",
|
"uuid",
|
||||||
"validator",
|
"validator",
|
||||||
|
"websocket",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -574,6 +578,12 @@ dependencies = [
|
||||||
"num-traits",
|
"num-traits",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "atomic_refcell"
|
||||||
|
version = "0.1.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "79d6dc922a2792b006573f60b2648076355daeae5ce9cb59507e5908c9625d31"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "autocfg"
|
name = "autocfg"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
|
@ -613,6 +623,26 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bindgen"
|
||||||
|
version = "0.64.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c4243e6031260db77ede97ad86c27e501d646a27ab57b59a574f725d98ab1fb4"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"cexpr",
|
||||||
|
"clang-sys",
|
||||||
|
"lazy_static",
|
||||||
|
"lazycell",
|
||||||
|
"peeking_take_while",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"regex",
|
||||||
|
"rustc-hash",
|
||||||
|
"shlex",
|
||||||
|
"syn 1.0.109",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bit-set"
|
name = "bit-set"
|
||||||
version = "0.5.3"
|
version = "0.5.3"
|
||||||
|
@ -700,6 +730,17 @@ dependencies = [
|
||||||
"bytes",
|
"bytes",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bzip2-sys"
|
||||||
|
version = "0.1.11+1.0.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
"libc",
|
||||||
|
"pkg-config",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.0.79"
|
version = "1.0.79"
|
||||||
|
@ -709,6 +750,15 @@ dependencies = [
|
||||||
"jobserver",
|
"jobserver",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cexpr"
|
||||||
|
version = "0.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
|
||||||
|
dependencies = [
|
||||||
|
"nom",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cfg-if"
|
name = "cfg-if"
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
|
@ -738,6 +788,17 @@ dependencies = [
|
||||||
"inout",
|
"inout",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clang-sys"
|
||||||
|
version = "1.6.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c688fc74432808e3eb684cae8830a86be1d66a2bd58e1f248ed0960a590baf6f"
|
||||||
|
dependencies = [
|
||||||
|
"glob",
|
||||||
|
"libc",
|
||||||
|
"libloading",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "codespan-reporting"
|
name = "codespan-reporting"
|
||||||
version = "0.11.1"
|
version = "0.11.1"
|
||||||
|
@ -748,6 +809,91 @@ dependencies = [
|
||||||
"unicode-width",
|
"unicode-width",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "collab"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"bytes",
|
||||||
|
"lib0",
|
||||||
|
"parking_lot 0.12.1",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"thiserror",
|
||||||
|
"tracing",
|
||||||
|
"y-sync",
|
||||||
|
"yrs",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "collab-client-ws"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"futures-util",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
|
"tokio-retry",
|
||||||
|
"tokio-stream",
|
||||||
|
"tokio-tungstenite",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "collab-persistence"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"bincode",
|
||||||
|
"chrono",
|
||||||
|
"lazy_static",
|
||||||
|
"lib0",
|
||||||
|
"parking_lot 0.12.1",
|
||||||
|
"rocksdb",
|
||||||
|
"serde",
|
||||||
|
"sled",
|
||||||
|
"smallvec",
|
||||||
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
|
"tracing",
|
||||||
|
"yrs",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "collab-plugins"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"collab",
|
||||||
|
"collab-client-ws",
|
||||||
|
"collab-persistence",
|
||||||
|
"collab-sync",
|
||||||
|
"tracing",
|
||||||
|
"y-sync",
|
||||||
|
"yrs",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "collab-sync"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"collab",
|
||||||
|
"futures-util",
|
||||||
|
"lib0",
|
||||||
|
"md5",
|
||||||
|
"parking_lot 0.12.1",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
|
"tokio-stream",
|
||||||
|
"tokio-util",
|
||||||
|
"tracing",
|
||||||
|
"y-sync",
|
||||||
|
"yrs",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "combine"
|
name = "combine"
|
||||||
version = "4.6.6"
|
version = "4.6.6"
|
||||||
|
@ -793,7 +939,7 @@ dependencies = [
|
||||||
"hkdf",
|
"hkdf",
|
||||||
"hmac",
|
"hmac",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"rand",
|
"rand 0.8.5",
|
||||||
"sha2",
|
"sha2",
|
||||||
"subtle",
|
"subtle",
|
||||||
"time",
|
"time",
|
||||||
|
@ -914,7 +1060,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"generic-array",
|
"generic-array",
|
||||||
"rand_core",
|
"rand_core 0.6.4",
|
||||||
"typenum",
|
"typenum",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1308,6 +1454,19 @@ dependencies = [
|
||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "getrandom"
|
||||||
|
version = "0.1.16"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"js-sys",
|
||||||
|
"libc",
|
||||||
|
"wasi 0.9.0+wasi-snapshot-preview1",
|
||||||
|
"wasm-bindgen",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "getrandom"
|
name = "getrandom"
|
||||||
version = "0.2.9"
|
version = "0.2.9"
|
||||||
|
@ -1316,7 +1475,7 @@ checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"libc",
|
"libc",
|
||||||
"wasi",
|
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1329,6 +1488,12 @@ dependencies = [
|
||||||
"polyval",
|
"polyval",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "glob"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "h2"
|
name = "h2"
|
||||||
version = "0.3.18"
|
version = "0.3.18"
|
||||||
|
@ -1635,12 +1800,65 @@ version = "1.4.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lazycell"
|
||||||
|
version = "1.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lib0"
|
||||||
|
version = "0.16.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "daf23122cb1c970b77ea6030eac5e328669415b65d2ab245c99bfb110f9d62dc"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libc"
|
name = "libc"
|
||||||
version = "0.2.142"
|
version = "0.2.142"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317"
|
checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libloading"
|
||||||
|
version = "0.7.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "librocksdb-sys"
|
||||||
|
version = "0.10.0+7.9.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0fe4d5874f5ff2bc616e55e8c6086d478fcda13faf9495768a4aa1c22042d30b"
|
||||||
|
dependencies = [
|
||||||
|
"bindgen",
|
||||||
|
"bzip2-sys",
|
||||||
|
"cc",
|
||||||
|
"glob",
|
||||||
|
"libc",
|
||||||
|
"libz-sys",
|
||||||
|
"zstd-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libz-sys"
|
||||||
|
version = "1.1.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "56ee889ecc9568871456d42f603d6a0ce59ff328d291063a45cbdf0036baf6db"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
"pkg-config",
|
||||||
|
"vcpkg",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "link-cplusplus"
|
name = "link-cplusplus"
|
||||||
version = "1.0.8"
|
version = "1.0.8"
|
||||||
|
@ -1723,6 +1941,12 @@ dependencies = [
|
||||||
"digest",
|
"digest",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "md5"
|
||||||
|
version = "0.7.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "memchr"
|
name = "memchr"
|
||||||
version = "2.5.0"
|
version = "2.5.0"
|
||||||
|
@ -1767,7 +1991,7 @@ checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"log",
|
"log",
|
||||||
"wasi",
|
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||||
"windows-sys 0.45.0",
|
"windows-sys 0.45.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1975,7 +2199,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
|
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"base64ct",
|
"base64ct",
|
||||||
"rand_core",
|
"rand_core 0.6.4",
|
||||||
"subtle",
|
"subtle",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1991,6 +2215,12 @@ version = "0.2.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd"
|
checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "peeking_take_while"
|
||||||
|
version = "0.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pem"
|
name = "pem"
|
||||||
version = "1.1.1"
|
version = "1.1.1"
|
||||||
|
@ -2096,6 +2326,19 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand"
|
||||||
|
version = "0.7.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
|
||||||
|
dependencies = [
|
||||||
|
"getrandom 0.1.16",
|
||||||
|
"libc",
|
||||||
|
"rand_chacha 0.2.2",
|
||||||
|
"rand_core 0.5.1",
|
||||||
|
"rand_hc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rand"
|
name = "rand"
|
||||||
version = "0.8.5"
|
version = "0.8.5"
|
||||||
|
@ -2103,8 +2346,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"rand_chacha",
|
"rand_chacha 0.3.1",
|
||||||
"rand_core",
|
"rand_core 0.6.4",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_chacha"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
|
||||||
|
dependencies = [
|
||||||
|
"ppv-lite86",
|
||||||
|
"rand_core 0.5.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2114,7 +2367,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ppv-lite86",
|
"ppv-lite86",
|
||||||
"rand_core",
|
"rand_core 0.6.4",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_core"
|
||||||
|
version = "0.5.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||||
|
dependencies = [
|
||||||
|
"getrandom 0.1.16",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2123,7 +2385,16 @@ version = "0.6.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"getrandom",
|
"getrandom 0.2.9",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_hc"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||||
|
dependencies = [
|
||||||
|
"rand_core 0.5.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2186,7 +2457,7 @@ version = "0.4.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
|
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"getrandom",
|
"getrandom 0.2.9",
|
||||||
"redox_syscall 0.2.16",
|
"redox_syscall 0.2.16",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
]
|
]
|
||||||
|
@ -2264,17 +2535,6 @@ dependencies = [
|
||||||
"winreg",
|
"winreg",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "revdb"
|
|
||||||
version = "0.1.0"
|
|
||||||
dependencies = [
|
|
||||||
"bincode",
|
|
||||||
"serde",
|
|
||||||
"sled",
|
|
||||||
"tempfile",
|
|
||||||
"thiserror",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ring"
|
name = "ring"
|
||||||
version = "0.16.20"
|
version = "0.16.20"
|
||||||
|
@ -2290,6 +2550,22 @@ dependencies = [
|
||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rocksdb"
|
||||||
|
version = "0.20.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "015439787fce1e75d55f279078d33ff14b4af5d93d995e8838ee4631301c8a99"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"librocksdb-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustc-hash"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustc_version"
|
name = "rustc_version"
|
||||||
version = "0.4.0"
|
version = "0.4.0"
|
||||||
|
@ -2504,6 +2780,12 @@ dependencies = [
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "shlex"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "signal-hook-registry"
|
name = "signal-hook-registry"
|
||||||
version = "1.4.1"
|
version = "1.4.1"
|
||||||
|
@ -2538,6 +2820,15 @@ dependencies = [
|
||||||
"parking_lot 0.11.2",
|
"parking_lot 0.11.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "smallstr"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1e922794d168678729ffc7e07182721a14219c65814e66e91b839a272fe5ae4f"
|
||||||
|
dependencies = [
|
||||||
|
"smallvec",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "smallvec"
|
name = "smallvec"
|
||||||
version = "1.10.0"
|
version = "1.10.0"
|
||||||
|
@ -2621,7 +2912,7 @@ dependencies = [
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"paste",
|
"paste",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"rand",
|
"rand 0.8.5",
|
||||||
"rustls",
|
"rustls",
|
||||||
"rustls-pemfile",
|
"rustls-pemfile",
|
||||||
"serde",
|
"serde",
|
||||||
|
@ -2881,6 +3172,17 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-retry"
|
||||||
|
version = "0.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f"
|
||||||
|
dependencies = [
|
||||||
|
"pin-project",
|
||||||
|
"rand 0.8.5",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-rustls"
|
name = "tokio-rustls"
|
||||||
version = "0.23.4"
|
version = "0.23.4"
|
||||||
|
@ -2901,6 +3203,19 @@ dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-tungstenite"
|
||||||
|
version = "0.18.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd"
|
||||||
|
dependencies = [
|
||||||
|
"futures-util",
|
||||||
|
"log",
|
||||||
|
"tokio",
|
||||||
|
"tungstenite",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3035,6 +3350,25 @@ version = "0.2.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
|
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tungstenite"
|
||||||
|
version = "0.18.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788"
|
||||||
|
dependencies = [
|
||||||
|
"base64 0.13.1",
|
||||||
|
"byteorder",
|
||||||
|
"bytes",
|
||||||
|
"http",
|
||||||
|
"httparse",
|
||||||
|
"log",
|
||||||
|
"rand 0.8.5",
|
||||||
|
"sha1",
|
||||||
|
"thiserror",
|
||||||
|
"url",
|
||||||
|
"utf-8",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typenum"
|
name = "typenum"
|
||||||
version = "1.16.0"
|
version = "1.16.0"
|
||||||
|
@ -3113,13 +3447,19 @@ dependencies = [
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "utf-8"
|
||||||
|
version = "0.7.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "uuid"
|
name = "uuid"
|
||||||
version = "1.3.2"
|
version = "1.3.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4dad5567ad0cf5b760e5665964bec1b47dfd077ba8a2544b513f3556d3d239a2"
|
checksum = "4dad5567ad0cf5b760e5665964bec1b47dfd077ba8a2544b513f3556d3d239a2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"getrandom",
|
"getrandom 0.2.9",
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -3166,6 +3506,12 @@ dependencies = [
|
||||||
"try-lock",
|
"try-lock",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasi"
|
||||||
|
version = "0.9.0+wasi-snapshot-preview1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wasi"
|
name = "wasi"
|
||||||
version = "0.11.0+wasi-snapshot-preview1"
|
version = "0.11.0+wasi-snapshot-preview1"
|
||||||
|
@ -3267,6 +3613,28 @@ dependencies = [
|
||||||
"webpki",
|
"webpki",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "websocket"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"actix",
|
||||||
|
"actix-web-actors",
|
||||||
|
"bytes",
|
||||||
|
"collab",
|
||||||
|
"collab-persistence",
|
||||||
|
"collab-plugins",
|
||||||
|
"collab-sync",
|
||||||
|
"dashmap",
|
||||||
|
"futures-util",
|
||||||
|
"parking_lot 0.12.1",
|
||||||
|
"secrecy",
|
||||||
|
"serde",
|
||||||
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
|
"tokio-stream",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "whoami"
|
name = "whoami"
|
||||||
version = "1.4.0"
|
version = "1.4.0"
|
||||||
|
@ -3492,6 +3860,17 @@ dependencies = [
|
||||||
"time",
|
"time",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "y-sync"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f54d34b68ec4514a0659838c2b1ba867c571b20b3804a1338dacf4fa9062d801"
|
||||||
|
dependencies = [
|
||||||
|
"lib0",
|
||||||
|
"thiserror",
|
||||||
|
"yrs",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "yaml-rust"
|
name = "yaml-rust"
|
||||||
version = "0.4.5"
|
version = "0.4.5"
|
||||||
|
@ -3510,6 +3889,20 @@ dependencies = [
|
||||||
"time",
|
"time",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "yrs"
|
||||||
|
version = "0.16.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4c2aef2bf89b4f7c003f9c73f1c8097427ca32e1d006443f3f607f11e79a797b"
|
||||||
|
dependencies = [
|
||||||
|
"atomic_refcell",
|
||||||
|
"lib0",
|
||||||
|
"rand 0.7.3",
|
||||||
|
"smallstr",
|
||||||
|
"smallvec",
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zeroize"
|
name = "zeroize"
|
||||||
version = "1.6.0"
|
version = "1.6.0"
|
||||||
|
|
19
Cargo.toml
19
Cargo.toml
|
@ -51,6 +51,8 @@ bytes = "1.4.0"
|
||||||
bincode = "1.3.3"
|
bincode = "1.3.3"
|
||||||
dashmap = "5.4"
|
dashmap = "5.4"
|
||||||
rcgen = { version = "0.10.0", features = ["pem", "x509-parser"] }
|
rcgen = { version = "0.10.0", features = ["pem", "x509-parser"] }
|
||||||
|
collab-sync = {version = "0.1.0"}
|
||||||
|
collab-persistence = {version = "0.1.0"}
|
||||||
|
|
||||||
# tracing
|
# tracing
|
||||||
tracing = { version = "0.1.37" }
|
tracing = { version = "0.1.37" }
|
||||||
|
@ -63,9 +65,11 @@ sqlx = { version = "0.6", default-features = false, features = ["runtime-actix-r
|
||||||
#Local crate
|
#Local crate
|
||||||
token = { path = "./crates/token" }
|
token = { path = "./crates/token" }
|
||||||
snowflake = { path = "./crates/snowflake" }
|
snowflake = { path = "./crates/snowflake" }
|
||||||
|
websocket = { path = "./crates/websocket" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
once_cell = "1.7.2"
|
once_cell = "1.7.2"
|
||||||
|
collab-client-ws = { version = "0.1.0" }
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
name = "appflowy_server"
|
name = "appflowy_server"
|
||||||
|
@ -78,6 +82,19 @@ path = "src/lib.rs"
|
||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
"crates/token",
|
"crates/token",
|
||||||
"crates/revdb",
|
|
||||||
"crates/snowflake",
|
"crates/snowflake",
|
||||||
|
"crates/websocket",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[patch.crates-io]
|
||||||
|
collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
|
||||||
|
collab-client-ws = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
|
||||||
|
collab-sync= { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
|
||||||
|
collab-persistence = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
|
||||||
|
collab-plugins = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" }
|
||||||
|
|
||||||
|
#collab = { path = "./crates/AppFlowy-Collab/collab" }
|
||||||
|
#collab-client-ws = { path = "./crates/AppFlowy-Collab/collab-client-ws" }
|
||||||
|
#collab-sync = { path = "./crates/AppFlowy-Collab/collab-sync" }
|
||||||
|
#collab-persistence = { path = "./crates/AppFlowy-Collab/collab-persistence" }
|
||||||
|
#collab-plugins = { path = "./crates/AppFlowy-Collab/collab-plugins"}
|
||||||
|
|
|
@ -3,6 +3,7 @@ application:
|
||||||
host: 0.0.0.0
|
host: 0.0.0.0
|
||||||
server_key: "Should-Use-The-Custom-Server-Key"
|
server_key: "Should-Use-The-Custom-Server-Key"
|
||||||
tls_config: "no_tls"
|
tls_config: "no_tls"
|
||||||
|
data_dir: "./data"
|
||||||
database:
|
database:
|
||||||
host: "localhost"
|
host: "localhost"
|
||||||
port: 5432
|
port: 5432
|
||||||
|
|
|
@ -1,15 +0,0 @@
|
||||||
[package]
|
|
||||||
name = "revdb"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2021"
|
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
sled = "0.34.7"
|
|
||||||
thiserror = "1.0.30"
|
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
|
||||||
bincode = "1.3.3"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
|
||||||
tempfile = "3.4.0"
|
|
|
@ -1,56 +0,0 @@
|
||||||
use crate::document::Document;
|
|
||||||
use crate::error::RevDBError;
|
|
||||||
use sled::{Batch, Db, IVec};
|
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
pub struct RevDB {
|
|
||||||
pub(crate) db: Db,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RevDB {
|
|
||||||
pub fn open(path: impl AsRef<Path>) -> Result<Self, RevDBError> {
|
|
||||||
let db = sled::open(path)?;
|
|
||||||
Ok(Self { db })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn document(&self) -> Document {
|
|
||||||
Document { db: self }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<IVec>, RevDBError> {
|
|
||||||
let value = self.db.get(key)?;
|
|
||||||
Ok(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn batch_get<K: AsRef<[u8]>>(
|
|
||||||
&self,
|
|
||||||
from_key: K,
|
|
||||||
to_key: K,
|
|
||||||
) -> Result<Vec<IVec>, RevDBError> {
|
|
||||||
let iter = self.db.range(from_key..=to_key);
|
|
||||||
let mut items = vec![];
|
|
||||||
for item in iter {
|
|
||||||
let (_, value) = item?;
|
|
||||||
items.push(value)
|
|
||||||
}
|
|
||||||
Ok(items)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn insert<K: AsRef<[u8]>>(&self, key: K, value: &[u8]) -> Result<(), RevDBError> {
|
|
||||||
let _ = self.db.insert(key, value)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn batch_insert<'a, K: AsRef<[u8]>>(
|
|
||||||
&self,
|
|
||||||
items: impl IntoIterator<Item = (K, &'a [u8])>,
|
|
||||||
) -> Result<(), RevDBError> {
|
|
||||||
let mut batch = Batch::default();
|
|
||||||
let items = items.into_iter();
|
|
||||||
items.for_each(|(key, value)| {
|
|
||||||
batch.insert(key.as_ref(), value);
|
|
||||||
});
|
|
||||||
self.db.apply_batch(batch)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,117 +0,0 @@
|
||||||
use crate::db::RevDB;
|
|
||||||
use crate::error::RevDBError;
|
|
||||||
use crate::range::RevRange;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
pub struct Document<'a> {
|
|
||||||
pub(crate) db: &'a RevDB,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> Document<'a> {
|
|
||||||
pub fn insert(
|
|
||||||
&self,
|
|
||||||
uid: i64,
|
|
||||||
document_id: i64,
|
|
||||||
value: DocumentRevData,
|
|
||||||
) -> Result<(), RevDBError> {
|
|
||||||
let key = make_document_key(uid, document_id, value.rev_id);
|
|
||||||
self.db.insert(key, &value.to_vec()?)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get(
|
|
||||||
&self,
|
|
||||||
uid: i64,
|
|
||||||
document_id: i64,
|
|
||||||
rev_id: i64,
|
|
||||||
) -> Result<Option<DocumentRevData>, RevDBError> {
|
|
||||||
let key = make_document_key(uid, document_id, rev_id);
|
|
||||||
match self.db.get(key)? {
|
|
||||||
None => Ok(None),
|
|
||||||
Some(value) => {
|
|
||||||
let data = DocumentRevData::from_vec(value.as_ref())?;
|
|
||||||
Ok(Some(data))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_with_range<R: Into<RevRange>>(
|
|
||||||
&self,
|
|
||||||
uid: i64,
|
|
||||||
document_id: i64,
|
|
||||||
range: R,
|
|
||||||
) -> Result<Vec<DocumentRevData>, RevDBError> {
|
|
||||||
let range = range.into();
|
|
||||||
let from = make_document_key(uid, document_id, range.start);
|
|
||||||
let to = make_document_key(uid, document_id, range.end);
|
|
||||||
self.batch_get(from, to)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_after(
|
|
||||||
&self,
|
|
||||||
uid: i64,
|
|
||||||
document_id: i64,
|
|
||||||
rev_id: i64,
|
|
||||||
) -> Result<Vec<DocumentRevData>, RevDBError> {
|
|
||||||
let from = make_document_key(uid, document_id, rev_id);
|
|
||||||
let to = make_document_key(uid, document_id, i64::MAX);
|
|
||||||
self.batch_get(from, to)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_before(
|
|
||||||
&self,
|
|
||||||
uid: i64,
|
|
||||||
document_id: i64,
|
|
||||||
rev_id: i64,
|
|
||||||
) -> Result<Vec<DocumentRevData>, RevDBError> {
|
|
||||||
let from = make_document_key(uid, document_id, 0);
|
|
||||||
let to = make_document_key(uid, document_id, rev_id);
|
|
||||||
self.batch_get(from, to)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn batch_get<K: AsRef<[u8]>>(
|
|
||||||
&self,
|
|
||||||
from: K,
|
|
||||||
to: K,
|
|
||||||
) -> Result<Vec<DocumentRevData>, RevDBError> {
|
|
||||||
let items = self.db.batch_get(from, to)?;
|
|
||||||
let mut document_revs = vec![];
|
|
||||||
for item in items {
|
|
||||||
let rev_data = DocumentRevData::from_vec(item.as_ref())?;
|
|
||||||
document_revs.push(rev_data);
|
|
||||||
}
|
|
||||||
Ok(document_revs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct DocumentRevData {
|
|
||||||
#[serde(rename = "rid")]
|
|
||||||
pub rev_id: i64,
|
|
||||||
|
|
||||||
#[serde(rename = "bid")]
|
|
||||||
pub base_rev_id: i64,
|
|
||||||
|
|
||||||
#[serde(rename = "data")]
|
|
||||||
pub content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DocumentRevData {
|
|
||||||
pub fn from_vec(data: &[u8]) -> Result<Self, RevDBError> {
|
|
||||||
bincode::deserialize::<Self>(data).map_err(|_e| RevDBError::SerdeError)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_vec(&self) -> Result<Vec<u8>, RevDBError> {
|
|
||||||
bincode::serialize(self).map_err(|_e| RevDBError::SerdeError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Optimize your data layout: Sled's B-Tree implementation works best when the keys are sequential,
|
|
||||||
// so try to organize the data in a way that maximizes sequential access.
|
|
||||||
fn make_document_key(uid: i64, document_id: i64, rev_id: i64) -> [u8; 24] {
|
|
||||||
let mut key = [0; 24];
|
|
||||||
key[0..8].copy_from_slice(&uid.to_be_bytes());
|
|
||||||
key[8..16].copy_from_slice(&document_id.to_be_bytes());
|
|
||||||
key[16..24].copy_from_slice(&rev_id.to_be_bytes());
|
|
||||||
key
|
|
||||||
}
|
|
|
@ -1,11 +0,0 @@
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
pub enum RevDBError {
|
|
||||||
#[error(transparent)]
|
|
||||||
Db(#[from] sled::Error),
|
|
||||||
|
|
||||||
#[error("Serde error")]
|
|
||||||
SerdeError,
|
|
||||||
|
|
||||||
#[error("invalid data")]
|
|
||||||
InvalidData,
|
|
||||||
}
|
|
|
@ -1,4 +0,0 @@
|
||||||
pub mod db;
|
|
||||||
pub mod document;
|
|
||||||
pub mod error;
|
|
||||||
pub mod range;
|
|
|
@ -1,48 +0,0 @@
|
||||||
use std::ops::{Range, RangeInclusive, RangeToInclusive};
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct RevRange {
|
|
||||||
pub(crate) start: i64,
|
|
||||||
pub(crate) end: i64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RevRange {
|
|
||||||
/// Construct a new `RevRange` representing the range [start..end).
|
|
||||||
/// It is an invariant that `start <= end`.
|
|
||||||
pub fn new(start: i64, end: i64) -> RevRange {
|
|
||||||
debug_assert!(start <= end);
|
|
||||||
RevRange { start, end }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<RangeInclusive<i64>> for RevRange {
|
|
||||||
fn from(src: RangeInclusive<i64>) -> RevRange {
|
|
||||||
RevRange::new(*src.start(), src.end().saturating_add(1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<RangeToInclusive<i64>> for RevRange {
|
|
||||||
fn from(src: RangeToInclusive<i64>) -> RevRange {
|
|
||||||
RevRange::new(0, src.end.saturating_add(1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<Range<i64>> for RevRange {
|
|
||||||
fn from(src: Range<i64>) -> RevRange {
|
|
||||||
let Range { start, end } = src;
|
|
||||||
RevRange { start, end }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Iterator for RevRange {
|
|
||||||
type Item = i64;
|
|
||||||
|
|
||||||
fn next(&mut self) -> Option<i64> {
|
|
||||||
if self.start > self.end {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
let val = self.start;
|
|
||||||
self.start += 1;
|
|
||||||
Some(val)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,2 +0,0 @@
|
||||||
mod test;
|
|
||||||
mod util;
|
|
|
@ -1,136 +0,0 @@
|
||||||
use crate::document::util::make_test_db;
|
|
||||||
use revdb::document::{Document, DocumentRevData};
|
|
||||||
use revdb::range::RevRange;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn insert_text() {
|
|
||||||
let db = make_test_db();
|
|
||||||
let document = db.document();
|
|
||||||
let uid = 12345678;
|
|
||||||
let document_id = 1;
|
|
||||||
let value = DocumentRevData {
|
|
||||||
rev_id: 0,
|
|
||||||
base_rev_id: 0,
|
|
||||||
content: "hello world".to_string(),
|
|
||||||
};
|
|
||||||
document.insert(uid, document_id, value.clone()).unwrap();
|
|
||||||
|
|
||||||
let restored_data = document.get(uid, document_id, 0).unwrap().unwrap();
|
|
||||||
assert_eq!(value.content, restored_data.content);
|
|
||||||
}
|
|
||||||
|
|
||||||
//noinspection RsExternalLinter
|
|
||||||
#[test]
|
|
||||||
fn insert_multi_text() {
|
|
||||||
let db = make_test_db();
|
|
||||||
let document = db.document();
|
|
||||||
let uid = 12345678;
|
|
||||||
let document_id = 1;
|
|
||||||
|
|
||||||
let mut base_rev_id = 0;
|
|
||||||
let mut expected_str = "".to_string();
|
|
||||||
for i in 0..=100 {
|
|
||||||
let content = i.to_string();
|
|
||||||
expected_str.push_str(&content);
|
|
||||||
let value = DocumentRevData {
|
|
||||||
rev_id: i,
|
|
||||||
base_rev_id,
|
|
||||||
content,
|
|
||||||
};
|
|
||||||
base_rev_id += 1;
|
|
||||||
document.insert(uid, document_id, value).unwrap();
|
|
||||||
}
|
|
||||||
//
|
|
||||||
let restored_str = document
|
|
||||||
.get_with_range(uid, document_id, RevRange::new(0, 100))
|
|
||||||
.unwrap()
|
|
||||||
.into_iter()
|
|
||||||
.map(|data| data.content)
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join("");
|
|
||||||
|
|
||||||
assert_eq!(expected_str, restored_str);
|
|
||||||
}
|
|
||||||
|
|
||||||
//noinspection RsExternalLinter
|
|
||||||
fn insert_100_string_to_document(uid: i64, document_id: i64, document: &Document) {
|
|
||||||
let mut base_rev_id = 0;
|
|
||||||
for i in 0..=100 {
|
|
||||||
let content = i.to_string();
|
|
||||||
let value = DocumentRevData {
|
|
||||||
rev_id: i,
|
|
||||||
base_rev_id,
|
|
||||||
content,
|
|
||||||
};
|
|
||||||
base_rev_id += 1;
|
|
||||||
document.insert(uid, document_id, value).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn values_to_string(values: Vec<DocumentRevData>) -> String {
|
|
||||||
values
|
|
||||||
.into_iter()
|
|
||||||
.map(|data| data.content)
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join("")
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn get_value_before() {
|
|
||||||
let db = make_test_db();
|
|
||||||
let document = db.document();
|
|
||||||
let uid = 12345678;
|
|
||||||
let document_id = 1;
|
|
||||||
insert_100_string_to_document(uid, document_id, &document);
|
|
||||||
|
|
||||||
let restored_str = values_to_string(document.get_before(uid, document_id, 50).unwrap());
|
|
||||||
assert_eq!("01234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950", restored_str);
|
|
||||||
|
|
||||||
let restored_str = values_to_string(document.get_before(uid, document_id, 0).unwrap());
|
|
||||||
assert_eq!("0", restored_str);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn get_value_after() {
|
|
||||||
let db = make_test_db();
|
|
||||||
let document = db.document();
|
|
||||||
let uid = 12345678;
|
|
||||||
let document_id = 1;
|
|
||||||
insert_100_string_to_document(uid, document_id, &document);
|
|
||||||
|
|
||||||
let restored_str = values_to_string(document.get_after(uid, document_id, 50).unwrap());
|
|
||||||
assert_eq!("5051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100", restored_str);
|
|
||||||
|
|
||||||
let restored_str = values_to_string(document.get_after(uid, document_id, 100).unwrap());
|
|
||||||
assert_eq!("100", restored_str);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn get_value_with_range() {
|
|
||||||
let db = make_test_db();
|
|
||||||
let document = db.document();
|
|
||||||
let uid = 12345678;
|
|
||||||
let document_id = 1;
|
|
||||||
insert_100_string_to_document(uid, document_id, &document);
|
|
||||||
|
|
||||||
let restored_str = values_to_string(
|
|
||||||
document
|
|
||||||
.get_with_range(uid, document_id, RevRange::new(50, 60))
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
assert_eq!("5051525354555657585960", restored_str);
|
|
||||||
|
|
||||||
let restored_str = values_to_string(
|
|
||||||
document
|
|
||||||
.get_with_range(uid, document_id, RevRange::new(50, 200))
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
assert_eq!("5051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100", restored_str);
|
|
||||||
|
|
||||||
let restored_str = values_to_string(
|
|
||||||
document
|
|
||||||
.get_with_range(uid, document_id, RevRange::new(50, 50))
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
assert_eq!("50", restored_str);
|
|
||||||
}
|
|
|
@ -1,7 +0,0 @@
|
||||||
use revdb::db::RevDB;
|
|
||||||
use tempfile::TempDir;
|
|
||||||
|
|
||||||
pub fn make_test_db() -> RevDB {
|
|
||||||
let tempdir = TempDir::new().unwrap();
|
|
||||||
RevDB::open(tempdir).unwrap()
|
|
||||||
}
|
|
|
@ -1 +0,0 @@
|
||||||
mod document;
|
|
|
@ -8,66 +8,65 @@ const TIMESTAMP_SHIFT: u64 = NODE_ID_BITS + SEQUENCE_BITS;
|
||||||
const SEQUENCE_MASK: u64 = (1 << SEQUENCE_BITS) - 1;
|
const SEQUENCE_MASK: u64 = (1 << SEQUENCE_BITS) - 1;
|
||||||
|
|
||||||
pub struct Snowflake {
|
pub struct Snowflake {
|
||||||
node_id: u64,
|
node_id: u64,
|
||||||
sequence: u64,
|
sequence: u64,
|
||||||
last_timestamp: u64,
|
last_timestamp: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Snowflake {
|
impl Snowflake {
|
||||||
pub fn new(node_id: u64) -> Snowflake {
|
pub fn new(node_id: u64) -> Snowflake {
|
||||||
Snowflake {
|
Snowflake {
|
||||||
node_id,
|
node_id,
|
||||||
sequence: 0,
|
sequence: 0,
|
||||||
last_timestamp: 0,
|
last_timestamp: 0,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn next_id(&mut self) -> i64 {
|
||||||
|
let timestamp = self.timestamp();
|
||||||
|
if timestamp < self.last_timestamp {
|
||||||
|
panic!("Clock moved backwards!");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn next_id(&mut self) -> i64 {
|
if timestamp == self.last_timestamp {
|
||||||
let timestamp = self.timestamp();
|
self.sequence = (self.sequence + 1) & SEQUENCE_MASK;
|
||||||
if timestamp < self.last_timestamp {
|
if self.sequence == 0 {
|
||||||
panic!("Clock moved backwards!");
|
self.wait_next_millis();
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
if timestamp == self.last_timestamp {
|
self.sequence = 0;
|
||||||
self.sequence = (self.sequence + 1) & SEQUENCE_MASK;
|
|
||||||
if self.sequence == 0 {
|
|
||||||
self.wait_next_millis();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
self.sequence = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
self.last_timestamp = timestamp;
|
|
||||||
let id =
|
|
||||||
(timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence;
|
|
||||||
id as i64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn wait_next_millis(&self) {
|
self.last_timestamp = timestamp;
|
||||||
let mut timestamp = self.timestamp();
|
let id = (timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence;
|
||||||
while timestamp == self.last_timestamp {
|
id as i64
|
||||||
timestamp = self.timestamp();
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn timestamp(&self) -> u64 {
|
fn wait_next_millis(&self) {
|
||||||
SystemTime::now()
|
let mut timestamp = self.timestamp();
|
||||||
.duration_since(SystemTime::UNIX_EPOCH)
|
while timestamp == self.last_timestamp {
|
||||||
.expect("Clock moved backwards!")
|
timestamp = self.timestamp();
|
||||||
.as_millis() as u64
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn timestamp(&self) -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
|
.expect("Clock moved backwards!")
|
||||||
|
.as_millis() as u64
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::Snowflake;
|
use crate::Snowflake;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn gen_id() {
|
fn gen_id() {
|
||||||
let mut snow_flake = Snowflake::new(1);
|
let mut snow_flake = Snowflake::new(1);
|
||||||
let id_1 = snow_flake.next_id();
|
let id_1 = snow_flake.next_id();
|
||||||
let id_2 = snow_flake.next_id();
|
let id_2 = snow_flake.next_id();
|
||||||
|
|
||||||
assert_ne!(id_1, id_2);
|
assert_ne!(id_1, id_2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,70 +7,72 @@ use sha2::Sha256;
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum TokenError {
|
pub enum TokenError {
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Jwt(#[from] jwt::Error),
|
Jwt(#[from] jwt::Error),
|
||||||
|
|
||||||
#[error("Token expired")]
|
#[error("Token expired")]
|
||||||
Expired,
|
Expired,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
|
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||||
pub enum TokenType {
|
pub enum TokenType {
|
||||||
AccessToken,
|
AccessToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct TokenFields<T> {
|
struct TokenFields<T> {
|
||||||
#[serde(rename = "d")]
|
#[serde(rename = "d")]
|
||||||
data: T,
|
data: T,
|
||||||
|
|
||||||
#[serde(rename = "exp")]
|
#[serde(rename = "exp")]
|
||||||
expire_at: DateTime<Utc>,
|
expire_at: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_token(
|
pub fn create_token(
|
||||||
server_key: &str,
|
server_key: &str,
|
||||||
data: impl Serialize,
|
data: impl Serialize,
|
||||||
expire_duration: Duration,
|
expire_duration: Duration,
|
||||||
) -> Result<String, TokenError> {
|
) -> Result<String, TokenError> {
|
||||||
Ok(TokenFields {
|
Ok(
|
||||||
data,
|
TokenFields {
|
||||||
expire_at: Utc::now() + expire_duration,
|
data,
|
||||||
|
expire_at: Utc::now() + expire_duration,
|
||||||
}
|
}
|
||||||
.sign_with_key(&generate_hmac_key(server_key))?)
|
.sign_with_key(&generate_hmac_key(server_key))?,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn generate_hmac_key(server_key: &str) -> Hmac<Sha256> {
|
fn generate_hmac_key(server_key: &str) -> Hmac<Sha256> {
|
||||||
Hmac::<Sha256>::new_from_slice(server_key.as_bytes()).expect("invalid server key")
|
Hmac::<Sha256>::new_from_slice(server_key.as_bytes()).expect("invalid server key")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse_token<T: DeserializeOwned>(server_key: &str, token: &str) -> Result<T, TokenError> {
|
pub fn parse_token<T: DeserializeOwned>(server_key: &str, token: &str) -> Result<T, TokenError> {
|
||||||
let fields =
|
let fields =
|
||||||
VerifyWithKey::<TokenFields<T>>::verify_with_key(token, &generate_hmac_key(server_key))?;
|
VerifyWithKey::<TokenFields<T>>::verify_with_key(token, &generate_hmac_key(server_key))?;
|
||||||
if fields.expire_at < Utc::now() {
|
if fields.expire_at < Utc::now() {
|
||||||
return Err(TokenError::Expired);
|
return Err(TokenError::Expired);
|
||||||
}
|
}
|
||||||
Ok(fields.data)
|
Ok(fields.data)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn create_token_test() {
|
fn create_token_test() {
|
||||||
let token_data = "hello appflowy".to_string();
|
let token_data = "hello appflowy".to_string();
|
||||||
let token = create_token("server_key", &token_data, Duration::days(2)).unwrap();
|
let token = create_token("server_key", &token_data, Duration::days(2)).unwrap();
|
||||||
|
|
||||||
let parse_token_data = parse_token::<String>("server_key", &token).unwrap();
|
let parse_token_data = parse_token::<String>("server_key", &token).unwrap();
|
||||||
assert_eq!(token_data, parse_token_data);
|
assert_eq!(token_data, parse_token_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic]
|
#[should_panic]
|
||||||
fn parser_token_with_different_server_key() {
|
fn parser_token_with_different_server_key() {
|
||||||
let server_key = "123456";
|
let server_key = "123456";
|
||||||
let token = create_token(server_key, "hello", Duration::days(2)).unwrap();
|
let token = create_token(server_key, "hello", Duration::days(2)).unwrap();
|
||||||
let _ = parse_token::<String>("abcdef", &token).unwrap();
|
let _ = parse_token::<String>("abcdef", &token).unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
25
crates/websocket/Cargo.toml
Normal file
25
crates/websocket/Cargo.toml
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
[package]
|
||||||
|
name = "websocket"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
actix = "0.13"
|
||||||
|
actix-web-actors = { version = "4.2.0" }
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
thiserror = "1.0.30"
|
||||||
|
bytes = "1.0"
|
||||||
|
secrecy = { version = "0.8", features = ["serde"] }
|
||||||
|
parking_lot = "0.12.1"
|
||||||
|
tracing = "0.1.25"
|
||||||
|
futures-util = "0.3.26"
|
||||||
|
tokio-stream = { version = "0.1.14", features = ["sync"] }
|
||||||
|
tokio = { version = "1.26", features = ["sync"] }
|
||||||
|
dashmap = "5.4.0"
|
||||||
|
|
||||||
|
collab = { version = "0.1.0"}
|
||||||
|
collab-sync = { version = "0.1.0"}
|
||||||
|
collab-persistence = { version = "0.1.0"}
|
||||||
|
collab-plugins = { version = "0.1.0", features = ["disk_rocksdb"]}
|
168
crates/websocket/src/client.rs
Normal file
168
crates/websocket/src/client.rs
Normal file
|
@ -0,0 +1,168 @@
|
||||||
|
use crate::entities::{ClientMessage, Connect, Disconnect, ServerMessage, WSUser};
|
||||||
|
use crate::error::WSError;
|
||||||
|
use crate::CollabServer;
|
||||||
|
use actix::{
|
||||||
|
fut, Actor, ActorContext, ActorFutureExt, Addr, AsyncContext, ContextFutureSpawner, Handler,
|
||||||
|
Recipient, Running, StreamHandler, WrapFuture,
|
||||||
|
};
|
||||||
|
use actix_web_actors::ws;
|
||||||
|
use bytes::Bytes;
|
||||||
|
|
||||||
|
use collab_sync::msg::CollabMessage;
|
||||||
|
use futures_util::Sink;
|
||||||
|
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
|
||||||
|
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||||
|
|
||||||
|
pub struct CollabSession {
|
||||||
|
user: Arc<WSUser>,
|
||||||
|
hb: Instant,
|
||||||
|
pub server: Addr<CollabServer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CollabSession {
|
||||||
|
pub fn new(user: WSUser, server: Addr<CollabServer>) -> Self {
|
||||||
|
Self {
|
||||||
|
user: Arc::new(user),
|
||||||
|
hb: Instant::now(),
|
||||||
|
server,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
|
||||||
|
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
|
||||||
|
if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
|
||||||
|
act.server.do_send(Disconnect {
|
||||||
|
user: act.user.clone(),
|
||||||
|
});
|
||||||
|
ctx.stop();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.ping(b"");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send_to_server(&self, bytes: Bytes) {
|
||||||
|
match CollabMessage::from_vec(bytes.to_vec()) {
|
||||||
|
Ok(collab_msg) => {
|
||||||
|
self.server.do_send(ClientMessage {
|
||||||
|
user: self.user.clone(),
|
||||||
|
collab_msg,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Error parsing message: {:?}", e);
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Actor for CollabSession {
|
||||||
|
type Context = ws::WebsocketContext<Self>;
|
||||||
|
|
||||||
|
fn started(&mut self, ctx: &mut Self::Context) {
|
||||||
|
// start heartbeats otherwise server disconnects in 10 seconds
|
||||||
|
self.hb(ctx);
|
||||||
|
|
||||||
|
self
|
||||||
|
.server
|
||||||
|
.send(Connect {
|
||||||
|
socket: ctx.address().recipient(),
|
||||||
|
user: self.user.clone(),
|
||||||
|
})
|
||||||
|
.into_actor(self)
|
||||||
|
.then(|res, _session, ctx| {
|
||||||
|
match res {
|
||||||
|
Ok(Ok(_)) => {
|
||||||
|
tracing::trace!("Send connect message to server success")
|
||||||
|
},
|
||||||
|
_ => {
|
||||||
|
tracing::error!("Send connect message to server failed");
|
||||||
|
ctx.stop();
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fut::ready(())
|
||||||
|
})
|
||||||
|
.wait(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stopping(&mut self, _: &mut Self::Context) -> Running {
|
||||||
|
self.server.do_send(Disconnect {
|
||||||
|
user: self.user.clone(),
|
||||||
|
});
|
||||||
|
Running::Stop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Handler<ServerMessage> for CollabSession {
|
||||||
|
type Result = ();
|
||||||
|
|
||||||
|
fn handle(&mut self, msg: ServerMessage, ctx: &mut Self::Context) {
|
||||||
|
ctx.binary(msg.collab_msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// WebSocket message handler
|
||||||
|
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for CollabSession {
|
||||||
|
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
|
||||||
|
let msg = match msg {
|
||||||
|
Err(_) => {
|
||||||
|
ctx.stop();
|
||||||
|
return;
|
||||||
|
},
|
||||||
|
Ok(msg) => msg,
|
||||||
|
};
|
||||||
|
|
||||||
|
match msg {
|
||||||
|
ws::Message::Ping(msg) => {
|
||||||
|
self.hb = Instant::now();
|
||||||
|
ctx.pong(&msg);
|
||||||
|
},
|
||||||
|
ws::Message::Pong(_) => {
|
||||||
|
self.hb = Instant::now();
|
||||||
|
},
|
||||||
|
ws::Message::Text(_) => {},
|
||||||
|
ws::Message::Binary(bytes) => {
|
||||||
|
self.send_to_server(bytes);
|
||||||
|
},
|
||||||
|
ws::Message::Close(reason) => {
|
||||||
|
ctx.close(reason);
|
||||||
|
ctx.stop();
|
||||||
|
},
|
||||||
|
ws::Message::Continuation(_) => {
|
||||||
|
ctx.stop();
|
||||||
|
},
|
||||||
|
ws::Message::Nop => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A helper struct that wraps the [Recipient] type to implement the [Sink] trait
|
||||||
|
pub struct ClientSink(pub Recipient<ServerMessage>);
|
||||||
|
|
||||||
|
impl Sink<CollabMessage> for ClientSink {
|
||||||
|
type Error = WSError;
|
||||||
|
|
||||||
|
fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(self: Pin<&mut Self>, item: CollabMessage) -> Result<(), Self::Error> {
|
||||||
|
self.0.do_send(ServerMessage { collab_msg: item });
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
56
crates/websocket/src/entities.rs
Normal file
56
crates/websocket/src/entities.rs
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
use crate::error::WSError;
|
||||||
|
use actix::{Message, Recipient};
|
||||||
|
|
||||||
|
use collab_sync::msg::CollabMessage;
|
||||||
|
use secrecy::{ExposeSecret, Secret};
|
||||||
|
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct WSUser {
|
||||||
|
pub user_id: Secret<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Hash for WSUser {
|
||||||
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||||
|
let uid: &String = self.user_id.expose_secret();
|
||||||
|
uid.hash(state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq<Self> for WSUser {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
let uid: &String = self.user_id.expose_secret();
|
||||||
|
let other_uid: &String = other.user_id.expose_secret();
|
||||||
|
uid == other_uid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Eq for WSUser {}
|
||||||
|
|
||||||
|
#[derive(Debug, Message, Clone)]
|
||||||
|
#[rtype(result = "Result<(), WSError>")]
|
||||||
|
pub struct Connect {
|
||||||
|
pub socket: Recipient<ServerMessage>,
|
||||||
|
pub user: Arc<WSUser>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Message, Clone)]
|
||||||
|
#[rtype(result = "Result<(), WSError>")]
|
||||||
|
pub struct Disconnect {
|
||||||
|
pub user: Arc<WSUser>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Message, Clone)]
|
||||||
|
#[rtype(result = "()")]
|
||||||
|
pub struct ClientMessage {
|
||||||
|
pub user: Arc<WSUser>,
|
||||||
|
pub collab_msg: CollabMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Message, Clone)]
|
||||||
|
#[rtype(result = "()")]
|
||||||
|
pub struct ServerMessage {
|
||||||
|
pub collab_msg: CollabMessage,
|
||||||
|
}
|
8
crates/websocket/src/error.rs
Normal file
8
crates/websocket/src/error.rs
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum WSError {
|
||||||
|
#[error(transparent)]
|
||||||
|
Persistence(#[from] collab_persistence::error::PersistenceError),
|
||||||
|
|
||||||
|
#[error("Internal failure: {0}")]
|
||||||
|
Internal(#[from] Box<dyn std::error::Error + Send + Sync>),
|
||||||
|
}
|
7
crates/websocket/src/lib.rs
Normal file
7
crates/websocket/src/lib.rs
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
mod client;
|
||||||
|
pub mod entities;
|
||||||
|
mod error;
|
||||||
|
mod server;
|
||||||
|
|
||||||
|
pub use client::*;
|
||||||
|
pub use server::*;
|
190
crates/websocket/src/server.rs
Normal file
190
crates/websocket/src/server.rs
Normal file
|
@ -0,0 +1,190 @@
|
||||||
|
use crate::entities::{ClientMessage, Connect, Disconnect, WSUser};
|
||||||
|
use crate::error::WSError;
|
||||||
|
use crate::ClientSink;
|
||||||
|
|
||||||
|
use actix::{Actor, Context, Handler, ResponseFuture};
|
||||||
|
use collab::core::collab::MutexCollab;
|
||||||
|
use collab::core::origin::CollabOrigin;
|
||||||
|
use collab_persistence::kv::rocks_kv::RocksCollabDB;
|
||||||
|
use collab_persistence::kv::KVStore;
|
||||||
|
use collab_plugins::disk_plugin::rocksdb_server::RocksdbServerDiskPlugin;
|
||||||
|
use collab_sync::server::{
|
||||||
|
CollabBroadcast, CollabGroup, CollabIDGen, CollabId, NonZeroNodeId, COLLAB_ID_LEN,
|
||||||
|
};
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use parking_lot::{Mutex, RwLock};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use collab_persistence::keys::make_collab_id_key;
|
||||||
|
use collab_sync::msg::CollabMessage;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::mpsc::Sender;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct CollabServer {
|
||||||
|
db: Arc<RocksCollabDB>,
|
||||||
|
/// Generate collab_id for new collab object
|
||||||
|
collab_id_gen: Arc<Mutex<CollabIDGen>>,
|
||||||
|
/// Memory cache for fast lookup of collab_id from object_id
|
||||||
|
collab_id_by_object_id: Arc<DashMap<String, CollabId>>,
|
||||||
|
collab_groups: Arc<RwLock<HashMap<CollabId, CollabGroup>>>,
|
||||||
|
client_streams: Arc<RwLock<HashMap<Arc<WSUser>, ClientStream>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CollabServer {
|
||||||
|
pub fn new(db: Arc<RocksCollabDB>) -> Result<Self, WSError> {
|
||||||
|
let collab_id_gen = Arc::new(Mutex::new(CollabIDGen::new(NonZeroNodeId(1))));
|
||||||
|
let collab_id_by_object_id = Arc::new(DashMap::new());
|
||||||
|
Ok(Self {
|
||||||
|
db,
|
||||||
|
collab_id_gen,
|
||||||
|
collab_id_by_object_id,
|
||||||
|
collab_groups: Default::default(),
|
||||||
|
client_streams: Default::default(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_collab_id(&self, object_id: &str) -> Result<CollabId, WSError> {
|
||||||
|
let collab_id = self.collab_id_gen.lock().next_id();
|
||||||
|
let collab_key = make_collab_id_key(object_id.as_ref());
|
||||||
|
self.db.with_write_txn(|w_txn| {
|
||||||
|
w_txn.insert(collab_key.as_ref(), collab_id.to_be_bytes())?;
|
||||||
|
Ok(())
|
||||||
|
})?;
|
||||||
|
Ok(collab_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_collab_id(&self, object_id: &str) -> Option<CollabId> {
|
||||||
|
let collab_key = make_collab_id_key(object_id.as_ref());
|
||||||
|
let read_txn = self.db.read_txn();
|
||||||
|
let value = read_txn.get(collab_key.as_ref()).ok()??;
|
||||||
|
|
||||||
|
let mut bytes = [0; COLLAB_ID_LEN];
|
||||||
|
bytes[0..COLLAB_ID_LEN].copy_from_slice(value.as_ref());
|
||||||
|
Some(CollabId::from_be_bytes(bytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_or_create_collab_id(&self, object_id: &str) -> Result<CollabId, WSError> {
|
||||||
|
let collab_id = self.get_collab_id(object_id);
|
||||||
|
if let Some(collab_id) = collab_id {
|
||||||
|
self.create_group_if_need(collab_id, object_id);
|
||||||
|
Ok(collab_id)
|
||||||
|
} else {
|
||||||
|
let collab_id = self.create_collab_id(object_id)?;
|
||||||
|
self
|
||||||
|
.collab_id_by_object_id
|
||||||
|
.insert(object_id.to_string(), collab_id);
|
||||||
|
self.create_group_if_need(collab_id, object_id);
|
||||||
|
Ok(collab_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_group_if_need(&self, collab_id: CollabId, object_id: &str) {
|
||||||
|
if self.collab_groups.read().contains_key(&collab_id) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let collab = MutexCollab::new(CollabOrigin::Empty, object_id, vec![]);
|
||||||
|
let plugin = RocksdbServerDiskPlugin::new(collab_id, self.db.clone()).unwrap();
|
||||||
|
collab.lock().add_plugin(Arc::new(plugin));
|
||||||
|
collab.initial();
|
||||||
|
|
||||||
|
let broadcast = CollabBroadcast::new(object_id, collab.clone(), 10);
|
||||||
|
let group = CollabGroup {
|
||||||
|
collab,
|
||||||
|
broadcast,
|
||||||
|
subscribers: Default::default(),
|
||||||
|
};
|
||||||
|
self.collab_groups.write().insert(collab_id, group);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Actor for CollabServer {
|
||||||
|
type Context = Context<Self>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Handler<Connect> for CollabServer {
|
||||||
|
type Result = Result<(), WSError>;
|
||||||
|
|
||||||
|
fn handle(&mut self, msg: Connect, _ctx: &mut Context<Self>) -> Self::Result {
|
||||||
|
let (stream_tx, rx) = tokio::sync::mpsc::channel(100);
|
||||||
|
let stream = ClientStream::new(ClientSink(msg.socket), ReceiverStream::new(rx), stream_tx);
|
||||||
|
self.client_streams.write().insert(msg.user, stream);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Handler<Disconnect> for CollabServer {
|
||||||
|
type Result = Result<(), WSError>;
|
||||||
|
fn handle(&mut self, msg: Disconnect, _: &mut Context<Self>) -> Self::Result {
|
||||||
|
self.client_streams.write().remove(&msg.user);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Handler<ClientMessage> for CollabServer {
|
||||||
|
type Result = ResponseFuture<()>;
|
||||||
|
|
||||||
|
fn handle(&mut self, msg: ClientMessage, _ctx: &mut Context<Self>) -> Self::Result {
|
||||||
|
let object_id = msg.collab_msg.object_id();
|
||||||
|
if let Ok(collab_id) = self.get_or_create_collab_id(object_id) {
|
||||||
|
if let Some(collab_group) = self.collab_groups.write().get_mut(&collab_id) {
|
||||||
|
if let Some(stream) = self.client_streams.write().get_mut(&msg.user) {
|
||||||
|
if let Some((sink, stream)) = stream.split() {
|
||||||
|
let origin = match msg.collab_msg.origin() {
|
||||||
|
None => CollabOrigin::Empty,
|
||||||
|
Some(client) => client.clone(),
|
||||||
|
};
|
||||||
|
let sub = collab_group
|
||||||
|
.broadcast
|
||||||
|
.subscribe(origin.clone(), sink, stream);
|
||||||
|
collab_group.subscribers.insert(origin, sub);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let client_streams = self.client_streams.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
if let Some(client_stream) = client_streams.read().get(&msg.user) {
|
||||||
|
let _ = client_stream.stream_tx.send(Ok(msg.collab_msg)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Box::pin(async move {})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl actix::Supervised for CollabServer {
|
||||||
|
fn restarting(&mut self, _ctx: &mut Context<CollabServer>) {
|
||||||
|
tracing::warn!("restarting");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ClientStream {
|
||||||
|
sink: Option<ClientSink>,
|
||||||
|
stream: Option<ReceiverStream<Result<CollabMessage, WSError>>>,
|
||||||
|
stream_tx: Sender<Result<CollabMessage, WSError>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientStream {
|
||||||
|
pub fn new(
|
||||||
|
sink: ClientSink,
|
||||||
|
stream: ReceiverStream<Result<CollabMessage, WSError>>,
|
||||||
|
stream_tx: Sender<Result<CollabMessage, WSError>>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
sink: Some(sink),
|
||||||
|
stream: Some(stream),
|
||||||
|
stream_tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn split(&mut self) -> Option<(ClientSink, ReceiverStream<Result<CollabMessage, WSError>>)> {
|
||||||
|
let sink = self.sink.take()?;
|
||||||
|
let stream = self.stream.take()?;
|
||||||
|
Some((sink, stream))
|
||||||
|
}
|
||||||
|
}
|
12
rustfmt.toml
Normal file
12
rustfmt.toml
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
# https://rust-lang.github.io/rustfmt/?version=master&search=
|
||||||
|
max_width = 100
|
||||||
|
tab_spaces = 2
|
||||||
|
newline_style = "Auto"
|
||||||
|
match_block_trailing_comma = true
|
||||||
|
use_field_init_shorthand = true
|
||||||
|
use_try_shorthand = true
|
||||||
|
reorder_imports = true
|
||||||
|
reorder_modules = true
|
||||||
|
remove_nested_parens = true
|
||||||
|
merge_derives = true
|
||||||
|
edition = "2021"
|
120
src/api/user.rs
120
src/api/user.rs
|
@ -1,6 +1,6 @@
|
||||||
use crate::component::auth::{
|
use crate::component::auth::{
|
||||||
change_password, logged_user_from_request, login, logout, register, ChangePasswordRequest,
|
change_password, logged_user_from_request, login, logout, register, ChangePasswordRequest,
|
||||||
InputParamsError, LoginRequest, RegisterRequest,
|
InputParamsError, LoginRequest, RegisterRequest,
|
||||||
};
|
};
|
||||||
use crate::component::token_state::SessionToken;
|
use crate::component::token_state::SessionToken;
|
||||||
use crate::domain::{UserEmail, UserName, UserPassword};
|
use crate::domain::{UserEmail, UserName, UserPassword};
|
||||||
|
@ -11,83 +11,83 @@ use actix_web::{web, HttpResponse, Scope};
|
||||||
use actix_web::{HttpRequest, Result};
|
use actix_web::{HttpRequest, Result};
|
||||||
|
|
||||||
pub fn user_scope() -> Scope {
|
pub fn user_scope() -> Scope {
|
||||||
web::scope("/api/user")
|
web::scope("/api/user")
|
||||||
.service(web::resource("/login").route(web::post().to(login_handler)))
|
.service(web::resource("/login").route(web::post().to(login_handler)))
|
||||||
.service(web::resource("/logout").route(web::get().to(logout_handler)))
|
.service(web::resource("/logout").route(web::get().to(logout_handler)))
|
||||||
.service(web::resource("/register").route(web::post().to(register_handler)))
|
.service(web::resource("/register").route(web::post().to(register_handler)))
|
||||||
.service(web::resource("/password").route(web::post().to(change_password_handler)))
|
.service(web::resource("/password").route(web::post().to(change_password_handler)))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn login_handler(
|
async fn login_handler(
|
||||||
req: Json<LoginRequest>,
|
req: Json<LoginRequest>,
|
||||||
state: Data<State>,
|
state: Data<State>,
|
||||||
session: SessionToken,
|
session: SessionToken,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let req = req.into_inner();
|
let req = req.into_inner();
|
||||||
let email = UserEmail::parse(req.email)
|
let email = UserEmail::parse(req.email)
|
||||||
.map_err(InputParamsError::InvalidEmail)?
|
.map_err(InputParamsError::InvalidEmail)?
|
||||||
.0;
|
.0;
|
||||||
let password = UserPassword::parse(req.password)
|
let password = UserPassword::parse(req.password)
|
||||||
.map_err(|_| InputParamsError::InvalidPassword)?
|
.map_err(|_| InputParamsError::InvalidPassword)?
|
||||||
.0;
|
.0;
|
||||||
let (resp, token) = login(email, password, &state).await?;
|
let (resp, token) = login(email, password, &state).await?;
|
||||||
|
|
||||||
// Renews the session key, assigning existing session state to new key.
|
// Renews the session key, assigning existing session state to new key.
|
||||||
session.renew();
|
session.renew();
|
||||||
if let Err(err) = session.insert_token(token) {
|
if let Err(err) = session.insert_token(token) {
|
||||||
// It needs to navigate to login page in web application
|
// It needs to navigate to login page in web application
|
||||||
tracing::error!("Insert session failed: {:?}", err);
|
tracing::error!("Insert session failed: {:?}", err);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(HttpResponse::Ok().json(resp))
|
Ok(HttpResponse::Ok().json(resp))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn logout_handler(req: HttpRequest, state: Data<State>) -> Result<HttpResponse> {
|
async fn logout_handler(req: HttpRequest, state: Data<State>) -> Result<HttpResponse> {
|
||||||
let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?;
|
let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?;
|
||||||
logout(logged_user, state.user.clone()).await;
|
logout(logged_user, state.user.clone()).await;
|
||||||
Ok(HttpResponse::Ok().finish())
|
Ok(HttpResponse::Ok().finish())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(level = "debug", skip(state))]
|
#[tracing::instrument(level = "debug", skip(state))]
|
||||||
async fn register_handler(req: Json<RegisterRequest>, state: Data<State>) -> Result<HttpResponse> {
|
async fn register_handler(req: Json<RegisterRequest>, state: Data<State>) -> Result<HttpResponse> {
|
||||||
let req = req.into_inner();
|
let req = req.into_inner();
|
||||||
let name = UserName::parse(req.name)
|
let name = UserName::parse(req.name)
|
||||||
.map_err(InputParamsError::InvalidName)?
|
.map_err(InputParamsError::InvalidName)?
|
||||||
.0;
|
.0;
|
||||||
let email = UserEmail::parse(req.email)
|
let email = UserEmail::parse(req.email)
|
||||||
.map_err(InputParamsError::InvalidEmail)?
|
.map_err(InputParamsError::InvalidEmail)?
|
||||||
.0;
|
.0;
|
||||||
let password = UserPassword::parse(req.password)
|
let password = UserPassword::parse(req.password)
|
||||||
.map_err(|_| InputParamsError::InvalidPassword)?
|
.map_err(|_| InputParamsError::InvalidPassword)?
|
||||||
.0;
|
.0;
|
||||||
|
|
||||||
let resp = register(name, email, password, &state).await?;
|
let resp = register(name, email, password, &state).await?;
|
||||||
Ok(HttpResponse::Ok().json(resp))
|
Ok(HttpResponse::Ok().json(resp))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn change_password_handler(
|
async fn change_password_handler(
|
||||||
req: HttpRequest,
|
req: HttpRequest,
|
||||||
payload: Json<ChangePasswordRequest>,
|
payload: Json<ChangePasswordRequest>,
|
||||||
// session: SessionToken,
|
// session: SessionToken,
|
||||||
state: Data<State>,
|
state: Data<State>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?;
|
let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?;
|
||||||
let payload = payload.into_inner();
|
let payload = payload.into_inner();
|
||||||
if payload.new_password != payload.new_password_confirm {
|
if payload.new_password != payload.new_password_confirm {
|
||||||
return Err(InputParamsError::PasswordNotMatch.into());
|
return Err(InputParamsError::PasswordNotMatch.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let new_password = UserPassword::parse(payload.new_password)
|
let new_password = UserPassword::parse(payload.new_password)
|
||||||
.map_err(|_| InputParamsError::InvalidPassword)?
|
.map_err(|_| InputParamsError::InvalidPassword)?
|
||||||
.0;
|
.0;
|
||||||
|
|
||||||
change_password(
|
change_password(
|
||||||
state.pg_pool.clone(),
|
state.pg_pool.clone(),
|
||||||
logged_user.clone(),
|
logged_user.clone(),
|
||||||
payload.current_password,
|
payload.current_password,
|
||||||
new_password,
|
new_password,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(HttpResponse::Ok().finish())
|
Ok(HttpResponse::Ok().finish())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,32 +1,40 @@
|
||||||
use crate::component::auth::LoggedUser;
|
use crate::component::auth::LoggedUser;
|
||||||
use crate::component::ws::{MessageReceivers, WSClient, WSServer};
|
|
||||||
use crate::state::State;
|
use crate::state::State;
|
||||||
use actix::Addr;
|
use actix::Addr;
|
||||||
use actix_web::web::{Data, Path, Payload};
|
use actix_web::web::{Data, Path, Payload};
|
||||||
use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope};
|
use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope};
|
||||||
use actix_web_actors::ws;
|
use actix_web_actors::ws;
|
||||||
|
use secrecy::Secret;
|
||||||
|
use websocket::entities::WSUser;
|
||||||
|
use websocket::{CollabServer, CollabSession};
|
||||||
|
|
||||||
pub fn ws_scope() -> Scope {
|
pub fn ws_scope() -> Scope {
|
||||||
web::scope("/ws").service(establish_ws_connection)
|
web::scope("/ws").service(establish_ws_connection)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/{token}")]
|
#[get("/{token}")]
|
||||||
pub async fn establish_ws_connection(
|
pub async fn establish_ws_connection(
|
||||||
request: HttpRequest,
|
request: HttpRequest,
|
||||||
payload: Payload,
|
payload: Payload,
|
||||||
token: Path<String>,
|
token: Path<String>,
|
||||||
state: Data<State>,
|
state: Data<State>,
|
||||||
server: Data<Addr<WSServer>>,
|
server: Data<Addr<CollabServer>>,
|
||||||
msg_receivers: Data<MessageReceivers>,
|
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
tracing::info!("establish_ws_connection");
|
let user = LoggedUser::from_token(&state.config.application.server_key, token.as_str())?;
|
||||||
let user = LoggedUser::from_token(&state.config.application.server_key, token.as_str())?;
|
let client = CollabSession::new(user.into(), server.get_ref().clone());
|
||||||
let client = WSClient::new(user, server.get_ref().clone(), msg_receivers);
|
match ws::start(client, &request, payload) {
|
||||||
match ws::start(client, &request, payload) {
|
Ok(response) => Ok(response),
|
||||||
Ok(response) => Ok(response),
|
Err(e) => {
|
||||||
Err(e) => {
|
tracing::error!("ws connection error: {:?}", e);
|
||||||
tracing::error!("ws connection error: {:?}", e);
|
Err(e)
|
||||||
Err(e)
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<LoggedUser> for WSUser {
|
||||||
|
fn from(user: LoggedUser) -> Self {
|
||||||
|
Self {
|
||||||
|
user_id: Secret::new(user.expose_secret().to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,9 @@ use actix_session::SessionMiddleware;
|
||||||
use actix_web::cookie::Key;
|
use actix_web::cookie::Key;
|
||||||
use actix_web::{dev::Server, web, web::Data, App, HttpServer};
|
use actix_web::{dev::Server, web, web::Data, App, HttpServer};
|
||||||
|
|
||||||
|
use actix::Actor;
|
||||||
|
|
||||||
|
use collab_persistence::kv::rocks_kv::RocksCollabDB;
|
||||||
use openssl::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod};
|
use openssl::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod};
|
||||||
use openssl::x509::X509;
|
use openssl::x509::X509;
|
||||||
use secrecy::{ExposeSecret, Secret};
|
use secrecy::{ExposeSecret, Secret};
|
||||||
|
@ -18,121 +21,136 @@ use sqlx::{postgres::PgPoolOptions, PgPool};
|
||||||
use std::net::TcpListener;
|
use std::net::TcpListener;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
use tracing_actix_web::TracingLogger;
|
use tracing_actix_web::TracingLogger;
|
||||||
|
use websocket::CollabServer;
|
||||||
|
|
||||||
pub struct Application {
|
pub struct Application {
|
||||||
port: u16,
|
port: u16,
|
||||||
server: Server,
|
server: Server,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Application {
|
impl Application {
|
||||||
pub async fn build(config: Config, state: State) -> Result<Self, anyhow::Error> {
|
pub async fn build(config: Config, state: State) -> Result<Self, anyhow::Error> {
|
||||||
let address = format!("{}:{}", config.application.host, config.application.port);
|
let address = format!("{}:{}", config.application.host, config.application.port);
|
||||||
let listener = TcpListener::bind(&address)?;
|
let listener = TcpListener::bind(&address)?;
|
||||||
let port = listener.local_addr().unwrap().port();
|
let port = listener.local_addr().unwrap().port();
|
||||||
let server = run(listener, state, config).await?;
|
let server = run(listener, state, config).await?;
|
||||||
|
|
||||||
Ok(Self { port, server })
|
Ok(Self { port, server })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_until_stopped(self) -> Result<(), std::io::Error> {
|
pub async fn run_until_stopped(self) -> Result<(), std::io::Error> {
|
||||||
self.server.await
|
self.server.await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn port(&self) -> u16 {
|
pub fn port(&self) -> u16 {
|
||||||
self.port
|
self.port
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
listener: TcpListener,
|
listener: TcpListener,
|
||||||
state: State,
|
state: State,
|
||||||
config: Config,
|
config: Config,
|
||||||
) -> Result<Server, anyhow::Error> {
|
) -> Result<Server, anyhow::Error> {
|
||||||
let redis_store = RedisSessionStore::new(config.redis_uri.expose_secret())
|
let redis_store = RedisSessionStore::new(config.redis_uri.expose_secret())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"Failed to connect to Redis at {:?}: {:?}",
|
"Failed to connect to Redis at {:?}: {:?}",
|
||||||
config.redis_uri,
|
config.redis_uri,
|
||||||
e
|
e
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
let pair = get_certificate_and_server_key(&config);
|
let pair = get_certificate_and_server_key(&config);
|
||||||
let key = pair
|
let key = pair
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|(_, server_key)| Key::from(server_key.expose_secret().as_bytes()))
|
.map(|(_, server_key)| Key::from(server_key.expose_secret().as_bytes()))
|
||||||
.unwrap_or_else(Key::generate);
|
.unwrap_or_else(Key::generate);
|
||||||
let mut server = HttpServer::new(move || {
|
|
||||||
App::new()
|
let collab_server = CollabServer::new(state.rocksdb.clone()).unwrap().start();
|
||||||
// Session middleware
|
let mut server = HttpServer::new(move || {
|
||||||
|
App::new()
|
||||||
.wrap(
|
.wrap(
|
||||||
SessionMiddleware::builder(redis_store.clone(), key.clone())
|
SessionMiddleware::builder(redis_store.clone(), key.clone())
|
||||||
.cookie_name(HEADER_TOKEN.to_string())
|
.cookie_name(HEADER_TOKEN.to_string())
|
||||||
.build(),
|
.build(),
|
||||||
)
|
)
|
||||||
|
// .wrap(ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, add_error_header))
|
||||||
.wrap(IdentityMiddleware::default())
|
.wrap(IdentityMiddleware::default())
|
||||||
.wrap(default_cors())
|
.wrap(default_cors())
|
||||||
.wrap(TracingLogger::default())
|
.wrap(TracingLogger::default())
|
||||||
.app_data(web::JsonConfig::default().limit(4096))
|
.app_data(web::JsonConfig::default().limit(4096))
|
||||||
.service(user_scope())
|
.service(user_scope())
|
||||||
.service(ws_scope())
|
.service(ws_scope())
|
||||||
|
.app_data(Data::new(collab_server.clone()))
|
||||||
.app_data(Data::new(state.clone()))
|
.app_data(Data::new(state.clone()))
|
||||||
});
|
});
|
||||||
|
|
||||||
server = match pair {
|
server = match pair {
|
||||||
None => server.listen(listener)?,
|
None => server.listen(listener)?,
|
||||||
Some((certificate, _)) => {
|
Some((certificate, _)) => {
|
||||||
server.listen_openssl(listener, make_ssl_acceptor_builder(certificate))?
|
server.listen_openssl(listener, make_ssl_acceptor_builder(certificate))?
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(server.run())
|
Ok(server.run())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_certificate_and_server_key(config: &Config) -> Option<(Secret<String>, Secret<String>)> {
|
fn get_certificate_and_server_key(config: &Config) -> Option<(Secret<String>, Secret<String>)> {
|
||||||
let tls_config = config.application.tls_config.as_ref()?;
|
let tls_config = config.application.tls_config.as_ref()?;
|
||||||
match tls_config {
|
match tls_config {
|
||||||
TlsConfig::NoTls => None,
|
TlsConfig::NoTls => None,
|
||||||
TlsConfig::SelfSigned => Some(create_self_signed_certificate().unwrap()),
|
TlsConfig::SelfSigned => Some(create_self_signed_certificate().unwrap()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn init_state(config: &Config) -> State {
|
pub async fn init_state(config: &Config) -> State {
|
||||||
let pg_pool = get_connection_pool(&config.database)
|
let pg_pool = get_connection_pool(&config.database)
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|_| panic!("Failed to connect to Postgres at {:?}.", config.database));
|
.unwrap_or_else(|_| panic!("Failed to connect to Postgres at {:?}.", config.database));
|
||||||
|
|
||||||
State {
|
std::fs::create_dir_all(config.application.rocksdb_db_dir()).expect("create rocksdb db dir");
|
||||||
pg_pool,
|
let rocksdb = Arc::new(RocksCollabDB::open(config.application.rocksdb_db_dir()).unwrap());
|
||||||
config: Arc::new(config.clone()),
|
State {
|
||||||
user: Arc::new(Default::default()),
|
pg_pool,
|
||||||
id_gen: Arc::new(RwLock::new(Snowflake::new(1))),
|
rocksdb,
|
||||||
}
|
config: Arc::new(config.clone()),
|
||||||
|
user: Arc::new(Default::default()),
|
||||||
|
id_gen: Arc::new(RwLock::new(Snowflake::new(1))),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_connection_pool(setting: &DatabaseSetting) -> Result<PgPool, sqlx::Error> {
|
pub async fn get_connection_pool(setting: &DatabaseSetting) -> Result<PgPool, sqlx::Error> {
|
||||||
PgPoolOptions::new()
|
PgPoolOptions::new()
|
||||||
.acquire_timeout(std::time::Duration::from_secs(5))
|
.acquire_timeout(std::time::Duration::from_secs(5))
|
||||||
.connect_with(setting.with_db())
|
.connect_with(setting.with_db())
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_ssl_acceptor_builder(certificate: Secret<String>) -> SslAcceptorBuilder {
|
fn make_ssl_acceptor_builder(certificate: Secret<String>) -> SslAcceptorBuilder {
|
||||||
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
|
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
|
||||||
let x509_cert = X509::from_pem(certificate.expose_secret().as_bytes()).unwrap();
|
let x509_cert = X509::from_pem(certificate.expose_secret().as_bytes()).unwrap();
|
||||||
builder.set_certificate(&x509_cert).unwrap();
|
builder.set_certificate(&x509_cert).unwrap();
|
||||||
builder
|
builder
|
||||||
.set_private_key_file("./cert/key.pem", SslFiletype::PEM)
|
.set_private_key_file("./cert/key.pem", SslFiletype::PEM)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
builder
|
builder
|
||||||
.set_certificate_chain_file("./cert/cert.pem")
|
.set_certificate_chain_file("./cert/cert.pem")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
builder
|
builder
|
||||||
.set_min_proto_version(Some(openssl::ssl::SslVersion::TLS1_2))
|
.set_min_proto_version(Some(openssl::ssl::SslVersion::TLS1_2))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
builder
|
builder
|
||||||
.set_max_proto_version(Some(openssl::ssl::SslVersion::TLS1_3))
|
.set_max_proto_version(Some(openssl::ssl::SslVersion::TLS1_3))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
builder
|
builder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fn add_error_header<B>(
|
||||||
|
// res: dev::ServiceResponse<B>,
|
||||||
|
// ) -> Result<ErrorHandlerResponse<B>, actix_web::Error> {
|
||||||
|
// tracing::error!("{:?}", res.request());
|
||||||
|
// Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
|
||||||
|
// }
|
||||||
|
|
|
@ -5,76 +5,76 @@ use thiserror::Error;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum AuthError {
|
pub enum AuthError {
|
||||||
#[error("Credentials is invalid")]
|
#[error("Credentials is invalid")]
|
||||||
InvalidCredentials(#[source] anyhow::Error),
|
InvalidCredentials(#[source] anyhow::Error),
|
||||||
|
|
||||||
#[error("User is not exist")]
|
#[error("User is not exist")]
|
||||||
UserNotExist(#[source] anyhow::Error),
|
UserNotExist(#[source] anyhow::Error),
|
||||||
|
|
||||||
#[error("{} is already used", email)]
|
#[error("{} is already used", email)]
|
||||||
UserAlreadyExist { email: String },
|
UserAlreadyExist { email: String },
|
||||||
|
|
||||||
#[error("Invalid password")]
|
#[error("Invalid password")]
|
||||||
InvalidPassword,
|
InvalidPassword,
|
||||||
|
|
||||||
#[error("User is unauthorized")]
|
#[error("User is unauthorized")]
|
||||||
Unauthorized,
|
Unauthorized,
|
||||||
|
|
||||||
#[error("User internal error")]
|
#[error("User internal error")]
|
||||||
InternalError(#[from] anyhow::Error),
|
InternalError(#[from] anyhow::Error),
|
||||||
|
|
||||||
#[error("Parser uuid failed: {}", err)]
|
#[error("Parser uuid failed: {}", err)]
|
||||||
InvalidUuid { err: String },
|
InvalidUuid { err: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn internal_error(error: anyhow::Error) -> AuthError {
|
pub fn internal_error(error: anyhow::Error) -> AuthError {
|
||||||
AuthError::InternalError(error)
|
AuthError::InternalError(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl actix_web::error::ResponseError for AuthError {
|
impl actix_web::error::ResponseError for AuthError {
|
||||||
fn status_code(&self) -> StatusCode {
|
fn status_code(&self) -> StatusCode {
|
||||||
match *self {
|
match *self {
|
||||||
AuthError::InvalidCredentials(_) => StatusCode::UNAUTHORIZED,
|
AuthError::InvalidCredentials(_) => StatusCode::UNAUTHORIZED,
|
||||||
AuthError::UserNotExist(_) => StatusCode::UNAUTHORIZED,
|
AuthError::UserNotExist(_) => StatusCode::UNAUTHORIZED,
|
||||||
AuthError::UserAlreadyExist { .. } => StatusCode::BAD_REQUEST,
|
AuthError::UserAlreadyExist { .. } => StatusCode::BAD_REQUEST,
|
||||||
AuthError::InvalidPassword => StatusCode::UNAUTHORIZED,
|
AuthError::InvalidPassword => StatusCode::UNAUTHORIZED,
|
||||||
AuthError::Unauthorized => StatusCode::UNAUTHORIZED,
|
AuthError::Unauthorized => StatusCode::UNAUTHORIZED,
|
||||||
AuthError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
AuthError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
AuthError::InvalidUuid { .. } => StatusCode::UNAUTHORIZED,
|
AuthError::InvalidUuid { .. } => StatusCode::UNAUTHORIZED,
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn error_response(&self) -> HttpResponse {
|
fn error_response(&self) -> HttpResponse {
|
||||||
HttpResponse::build(self.status_code()).body(self.to_string())
|
HttpResponse::build(self.status_code()).body(self.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum InputParamsError {
|
pub enum InputParamsError {
|
||||||
#[error("Invalid name")]
|
#[error("Invalid name")]
|
||||||
InvalidName(String),
|
InvalidName(String),
|
||||||
|
|
||||||
#[error("Invalid email format")]
|
#[error("Invalid email format")]
|
||||||
InvalidEmail(String),
|
InvalidEmail(String),
|
||||||
|
|
||||||
#[error("Invalid password")]
|
#[error("Invalid password")]
|
||||||
InvalidPassword,
|
InvalidPassword,
|
||||||
|
|
||||||
#[error("You entered two different new passwords")]
|
#[error("You entered two different new passwords")]
|
||||||
PasswordNotMatch,
|
PasswordNotMatch,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl actix_web::error::ResponseError for InputParamsError {
|
impl actix_web::error::ResponseError for InputParamsError {
|
||||||
fn status_code(&self) -> StatusCode {
|
fn status_code(&self) -> StatusCode {
|
||||||
match *self {
|
match *self {
|
||||||
InputParamsError::InvalidName(_) => StatusCode::BAD_REQUEST,
|
InputParamsError::InvalidName(_) => StatusCode::BAD_REQUEST,
|
||||||
InputParamsError::InvalidEmail(_) => StatusCode::BAD_REQUEST,
|
InputParamsError::InvalidEmail(_) => StatusCode::BAD_REQUEST,
|
||||||
InputParamsError::InvalidPassword => StatusCode::BAD_REQUEST,
|
InputParamsError::InvalidPassword => StatusCode::BAD_REQUEST,
|
||||||
InputParamsError::PasswordNotMatch => StatusCode::BAD_REQUEST,
|
InputParamsError::PasswordNotMatch => StatusCode::BAD_REQUEST,
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn error_response(&self) -> HttpResponse {
|
fn error_response(&self) -> HttpResponse {
|
||||||
HttpResponse::build(self.status_code()).body(self.to_string())
|
HttpResponse::build(self.status_code()).body(self.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,84 +8,85 @@ use secrecy::{ExposeSecret, Secret};
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
|
||||||
pub struct Credentials {
|
pub struct Credentials {
|
||||||
pub email: String,
|
pub email: String,
|
||||||
pub password: Secret<String>,
|
pub password: Secret<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(level = "debug", skip(credentials, pool))]
|
#[tracing::instrument(level = "debug", skip(credentials, pool))]
|
||||||
pub async fn validate_credentials(
|
pub async fn validate_credentials(
|
||||||
credentials: Credentials,
|
credentials: Credentials,
|
||||||
pool: &PgPool,
|
pool: &PgPool,
|
||||||
) -> Result<i64, AuthError> {
|
) -> Result<i64, AuthError> {
|
||||||
let mut uid = None;
|
let mut uid = None;
|
||||||
let mut expected_hash_password = Secret::new(
|
let mut expected_hash_password = Secret::new(
|
||||||
"$argon2id$v=19$m=15000,t=2,p=1$\
|
"$argon2id$v=19$m=15000,t=2,p=1$\
|
||||||
gZiV/M1gPc22ElAH/Jh1Hw$\
|
gZiV/M1gPc22ElAH/Jh1Hw$\
|
||||||
CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno"
|
CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno"
|
||||||
.to_string(),
|
.to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Some((stored_uid, stored_hash_password)) =
|
if let Some((stored_uid, stored_hash_password)) =
|
||||||
get_stored_credentials(&credentials.email, pool).await?
|
get_stored_credentials(&credentials.email, pool).await?
|
||||||
{
|
{
|
||||||
uid = Some(stored_uid);
|
uid = Some(stored_uid);
|
||||||
expected_hash_password = stored_hash_password;
|
expected_hash_password = stored_hash_password;
|
||||||
}
|
}
|
||||||
|
|
||||||
spawn_blocking_with_tracing(move || {
|
spawn_blocking_with_tracing(move || {
|
||||||
verify_password_hash(expected_hash_password, credentials.password)
|
verify_password_hash(expected_hash_password, credentials.password)
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.context("Failed to spawn blocking task.")??;
|
.context("Failed to spawn blocking task.")??;
|
||||||
|
|
||||||
uid.ok_or_else(|| anyhow::anyhow!("Unknown email."))
|
uid
|
||||||
.map_err(AuthError::InvalidCredentials)
|
.ok_or_else(|| anyhow::anyhow!("Unknown email."))
|
||||||
|
.map_err(AuthError::InvalidCredentials)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn compute_hash_password(password: &[u8]) -> Result<Secret<String>, anyhow::Error> {
|
pub fn compute_hash_password(password: &[u8]) -> Result<Secret<String>, anyhow::Error> {
|
||||||
let salt = SaltString::generate(&mut rand::thread_rng());
|
let salt = SaltString::generate(&mut rand::thread_rng());
|
||||||
let password = Argon2::new(
|
let password = Argon2::new(
|
||||||
Algorithm::Argon2id,
|
Algorithm::Argon2id,
|
||||||
Version::V0x13,
|
Version::V0x13,
|
||||||
Params::new(15000, 2, 1, None).unwrap(),
|
Params::new(15000, 2, 1, None).unwrap(),
|
||||||
)
|
)
|
||||||
.hash_password(password, &salt)?
|
.hash_password(password, &salt)?
|
||||||
.to_string();
|
.to_string();
|
||||||
Ok(Secret::new(password))
|
Ok(Secret::new(password))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(level = "debug", skip(email, pool))]
|
#[tracing::instrument(level = "debug", skip(email, pool))]
|
||||||
async fn get_stored_credentials(
|
async fn get_stored_credentials(
|
||||||
email: &str,
|
email: &str,
|
||||||
pool: &PgPool,
|
pool: &PgPool,
|
||||||
) -> Result<Option<(i64, Secret<String>)>, anyhow::Error> {
|
) -> Result<Option<(i64, Secret<String>)>, anyhow::Error> {
|
||||||
let row = sqlx::query!(
|
let row = sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
SELECT uid, password
|
SELECT uid, password
|
||||||
FROM users
|
FROM users
|
||||||
WHERE email = $1
|
WHERE email = $1
|
||||||
"#,
|
"#,
|
||||||
email,
|
email,
|
||||||
)
|
)
|
||||||
.fetch_optional(pool)
|
.fetch_optional(pool)
|
||||||
.await
|
.await
|
||||||
.context("Failed to performed a query to retrieve stored credentials.")?
|
.context("Failed to performed a query to retrieve stored credentials.")?
|
||||||
.map(|row| (row.uid, Secret::new(row.password)));
|
.map(|row| (row.uid, Secret::new(row.password)));
|
||||||
Ok(row)
|
Ok(row)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn verify_password_hash(
|
fn verify_password_hash(
|
||||||
expected_password_hash: Secret<String>,
|
expected_password_hash: Secret<String>,
|
||||||
password_candidate: Secret<String>,
|
password_candidate: Secret<String>,
|
||||||
) -> Result<(), AuthError> {
|
) -> Result<(), AuthError> {
|
||||||
let expected_hash_password = PasswordHash::new(expected_password_hash.expose_secret())
|
let expected_hash_password = PasswordHash::new(expected_password_hash.expose_secret())
|
||||||
.context("Failed to parse hash in PHC string format.")?;
|
.context("Failed to parse hash in PHC string format.")?;
|
||||||
|
|
||||||
Argon2::default()
|
Argon2::default()
|
||||||
.verify_password(
|
.verify_password(
|
||||||
password_candidate.expose_secret().as_bytes(),
|
password_candidate.expose_secret().as_bytes(),
|
||||||
&expected_hash_password,
|
&expected_hash_password,
|
||||||
)
|
)
|
||||||
.context("Invalid password.")
|
.context("Invalid password.")
|
||||||
.map_err(|_| AuthError::InvalidPassword)
|
.map_err(|_| AuthError::InvalidPassword)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::component::auth::{
|
use crate::component::auth::{
|
||||||
compute_hash_password, internal_error, validate_credentials, AuthError, Credentials,
|
compute_hash_password, internal_error, validate_credentials, AuthError, Credentials,
|
||||||
};
|
};
|
||||||
use crate::config::env::domain;
|
use crate::config::env::domain;
|
||||||
use crate::state::{State, UserCache};
|
use crate::state::{State, UserCache};
|
||||||
|
@ -18,234 +18,234 @@ use token::{create_token, parse_token, TokenError};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
pub async fn login(
|
pub async fn login(
|
||||||
email: String,
|
email: String,
|
||||||
password: String,
|
password: String,
|
||||||
state: &State,
|
state: &State,
|
||||||
) -> Result<(LoginResponse, Secret<Token>), AuthError> {
|
) -> Result<(LoginResponse, Secret<Token>), AuthError> {
|
||||||
let credentials = Credentials {
|
let credentials = Credentials {
|
||||||
email,
|
email,
|
||||||
password: Secret::new(password),
|
password: Secret::new(password),
|
||||||
};
|
};
|
||||||
let server_key = &state.config.application.server_key;
|
let server_key = &state.config.application.server_key;
|
||||||
|
|
||||||
match validate_credentials(credentials, &state.pg_pool).await {
|
match validate_credentials(credentials, &state.pg_pool).await {
|
||||||
Ok(uid) => {
|
Ok(uid) => {
|
||||||
let token = Token::create_token(uid, server_key)?;
|
let token = Token::create_token(uid, server_key)?;
|
||||||
let logged_user = LoggedUser::new(uid);
|
let logged_user = LoggedUser::new(uid);
|
||||||
state.user.write().await.authorized(logged_user);
|
state.user.write().await.authorized(logged_user);
|
||||||
Ok((
|
Ok((
|
||||||
LoginResponse {
|
LoginResponse {
|
||||||
token: token.0.clone(),
|
token: token.0.clone(),
|
||||||
uid: uid.to_string(),
|
uid: uid.to_string(),
|
||||||
},
|
},
|
||||||
Secret::new(token),
|
Secret::new(token),
|
||||||
))
|
))
|
||||||
}
|
},
|
||||||
Err(err) => Err(err),
|
Err(err) => Err(err),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn logout(logged_user: LoggedUser, cache: Arc<RwLock<UserCache>>) {
|
pub async fn logout(logged_user: LoggedUser, cache: Arc<RwLock<UserCache>>) {
|
||||||
cache.write().await.unauthorized(logged_user);
|
cache.write().await.unauthorized(logged_user);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn register(
|
pub async fn register(
|
||||||
username: String,
|
username: String,
|
||||||
email: String,
|
email: String,
|
||||||
password: String,
|
password: String,
|
||||||
state: &State,
|
state: &State,
|
||||||
) -> Result<RegisterResponse, AuthError> {
|
) -> Result<RegisterResponse, AuthError> {
|
||||||
let pg_pool = state.pg_pool.clone();
|
let pg_pool = state.pg_pool.clone();
|
||||||
let server_key = &state.config.application.server_key;
|
let server_key = &state.config.application.server_key;
|
||||||
let mut transaction = pg_pool
|
let mut transaction = pg_pool
|
||||||
.begin()
|
.begin()
|
||||||
.await
|
.await
|
||||||
.context("Failed to acquire a Postgres connection to register user")
|
.context("Failed to acquire a Postgres connection to register user")
|
||||||
.map_err(internal_error)?;
|
.map_err(internal_error)?;
|
||||||
|
|
||||||
if is_email_exist(&mut transaction, email.as_ref())
|
if is_email_exist(&mut transaction, email.as_ref())
|
||||||
.await
|
.await
|
||||||
.map_err(internal_error)?
|
.map_err(internal_error)?
|
||||||
{
|
{
|
||||||
return Err(AuthError::UserAlreadyExist { email });
|
return Err(AuthError::UserAlreadyExist { email });
|
||||||
}
|
}
|
||||||
|
|
||||||
let uid = state.id_gen.write().await.next_id();
|
let uid = state.id_gen.write().await.next_id();
|
||||||
let token = Token::create_token(uid, server_key)?;
|
let token = Token::create_token(uid, server_key)?;
|
||||||
let password = compute_hash_password(password.as_bytes()).map_err(internal_error)?;
|
let password = compute_hash_password(password.as_bytes()).map_err(internal_error)?;
|
||||||
let _ = sqlx::query!(
|
let _ = sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
INSERT INTO users (uid, email, username, create_time, password)
|
INSERT INTO users (uid, email, username, create_time, password)
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
"#,
|
"#,
|
||||||
uid,
|
uid,
|
||||||
email,
|
email,
|
||||||
username,
|
username,
|
||||||
Utc::now(),
|
Utc::now(),
|
||||||
password.expose_secret(),
|
password.expose_secret(),
|
||||||
)
|
)
|
||||||
.execute(&mut transaction)
|
.execute(&mut transaction)
|
||||||
|
.await
|
||||||
|
.context("Save user to disk failed")
|
||||||
|
.map_err(internal_error)?;
|
||||||
|
|
||||||
|
transaction
|
||||||
|
.commit()
|
||||||
.await
|
.await
|
||||||
.context("Save user to disk failed")
|
.context("Failed to commit SQL transaction to register user.")
|
||||||
.map_err(internal_error)?;
|
.map_err(internal_error)?;
|
||||||
|
|
||||||
transaction
|
let logged_user = LoggedUser::new(uid);
|
||||||
.commit()
|
state.user.write().await.authorized(logged_user);
|
||||||
.await
|
|
||||||
.context("Failed to commit SQL transaction to register user.")
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
|
|
||||||
let logged_user = LoggedUser::new(uid);
|
Ok(RegisterResponse {
|
||||||
state.user.write().await.authorized(logged_user);
|
token: token.0.clone(),
|
||||||
|
})
|
||||||
Ok(RegisterResponse {
|
|
||||||
token: token.0.clone(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn change_password(
|
pub async fn change_password(
|
||||||
pg_pool: PgPool,
|
pg_pool: PgPool,
|
||||||
logged_user: LoggedUser,
|
logged_user: LoggedUser,
|
||||||
current_password: String,
|
current_password: String,
|
||||||
new_password: String,
|
new_password: String,
|
||||||
) -> Result<(), AuthError> {
|
) -> Result<(), AuthError> {
|
||||||
let mut transaction = pg_pool
|
let mut transaction = pg_pool
|
||||||
.begin()
|
.begin()
|
||||||
.await
|
.await
|
||||||
.context("Failed to acquire a Postgres connection to change password")
|
.context("Failed to acquire a Postgres connection to change password")
|
||||||
.map_err(internal_error)?;
|
.map_err(internal_error)?;
|
||||||
|
|
||||||
let email = get_user_email(*logged_user.expose_secret(), &mut transaction).await?;
|
let email = get_user_email(*logged_user.expose_secret(), &mut transaction).await?;
|
||||||
|
|
||||||
// check password
|
// check password
|
||||||
let credentials = Credentials {
|
let credentials = Credentials {
|
||||||
email,
|
email,
|
||||||
password: Secret::new(current_password),
|
password: Secret::new(current_password),
|
||||||
};
|
};
|
||||||
let _ = validate_credentials(credentials, &pg_pool).await?;
|
let _ = validate_credentials(credentials, &pg_pool).await?;
|
||||||
|
|
||||||
// Hash password
|
// Hash password
|
||||||
let new_hash_password =
|
let new_hash_password =
|
||||||
spawn_blocking_with_tracing(move || compute_hash_password(new_password.as_bytes()))
|
spawn_blocking_with_tracing(move || compute_hash_password(new_password.as_bytes()))
|
||||||
.await
|
.await
|
||||||
.context("Failed to hash password")??;
|
.context("Failed to hash password")??;
|
||||||
|
|
||||||
// Save password to disk
|
// Save password to disk
|
||||||
let sql = "UPDATE users SET password = $1 where uid = $2";
|
let sql = "UPDATE users SET password = $1 where uid = $2";
|
||||||
let _ = sqlx::query(sql)
|
let _ = sqlx::query(sql)
|
||||||
.bind(new_hash_password.expose_secret())
|
.bind(new_hash_password.expose_secret())
|
||||||
.bind(logged_user.expose_secret())
|
.bind(logged_user.expose_secret())
|
||||||
.execute(&mut transaction)
|
.execute(&mut transaction)
|
||||||
.await
|
.await
|
||||||
.context("Failed to change user's password in the database.")?;
|
.context("Failed to change user's password in the database.")?;
|
||||||
|
|
||||||
transaction
|
transaction
|
||||||
.commit()
|
.commit()
|
||||||
.await
|
.await
|
||||||
.context("Failed to commit SQL transaction to change user's password")
|
.context("Failed to commit SQL transaction to change user's password")
|
||||||
.map_err(internal_error)?;
|
.map_err(internal_error)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_user_email(
|
pub async fn get_user_email(
|
||||||
uid: i64,
|
uid: i64,
|
||||||
transaction: &mut Transaction<'_, Postgres>,
|
transaction: &mut Transaction<'_, Postgres>,
|
||||||
) -> Result<String, anyhow::Error> {
|
) -> Result<String, anyhow::Error> {
|
||||||
let row = sqlx::query!(
|
let row = sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
SELECT email
|
SELECT email
|
||||||
FROM users
|
FROM users
|
||||||
WHERE uid = $1
|
WHERE uid = $1
|
||||||
"#,
|
"#,
|
||||||
uid,
|
uid,
|
||||||
)
|
)
|
||||||
.fetch_one(transaction)
|
.fetch_one(transaction)
|
||||||
.await
|
.await
|
||||||
.context("Failed to retrieve the username`")?;
|
.context("Failed to retrieve the username`")?;
|
||||||
Ok(row.email)
|
Ok(row.email)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// TODO: cache this state in [State]
|
/// TODO: cache this state in [State]
|
||||||
async fn is_email_exist(
|
async fn is_email_exist(
|
||||||
transaction: &mut Transaction<'_, Postgres>,
|
transaction: &mut Transaction<'_, Postgres>,
|
||||||
email: &str,
|
email: &str,
|
||||||
) -> Result<bool, anyhow::Error> {
|
) -> Result<bool, anyhow::Error> {
|
||||||
let result = sqlx::query(r#"SELECT email FROM users WHERE email = $1"#)
|
let result = sqlx::query(r#"SELECT email FROM users WHERE email = $1"#)
|
||||||
.bind(email)
|
.bind(email)
|
||||||
.fetch_optional(transaction)
|
.fetch_optional(transaction)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(result.is_some())
|
Ok(result.is_some())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Deserialize, Debug)]
|
#[derive(Default, Deserialize, Debug)]
|
||||||
pub struct LoginRequest {
|
pub struct LoginRequest {
|
||||||
pub email: String,
|
pub email: String,
|
||||||
pub password: String,
|
pub password: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Serialize, Deserialize, Debug)]
|
#[derive(Default, Serialize, Deserialize, Debug)]
|
||||||
pub struct LoginResponse {
|
pub struct LoginResponse {
|
||||||
pub token: String,
|
pub token: String,
|
||||||
pub uid: String,
|
pub uid: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Deserialize, Debug)]
|
#[derive(Default, Deserialize, Debug)]
|
||||||
pub struct RegisterRequest {
|
pub struct RegisterRequest {
|
||||||
pub email: String,
|
pub email: String,
|
||||||
pub password: String,
|
pub password: String,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Serialize, Deserialize, Debug)]
|
#[derive(Default, Serialize, Deserialize, Debug)]
|
||||||
pub struct RegisterResponse {
|
pub struct RegisterResponse {
|
||||||
pub token: String,
|
pub token: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Deserialize, Debug)]
|
#[derive(Default, Deserialize, Debug)]
|
||||||
pub struct ChangePasswordRequest {
|
pub struct ChangePasswordRequest {
|
||||||
pub current_password: String,
|
pub current_password: String,
|
||||||
pub new_password: String,
|
pub new_password: String,
|
||||||
pub new_password_confirm: String,
|
pub new_password_confirm: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Default)]
|
#[derive(Clone, Default)]
|
||||||
pub struct WrapI64(i64);
|
pub struct SecretI64(i64);
|
||||||
impl Copy for WrapI64 {}
|
impl Copy for SecretI64 {}
|
||||||
impl DefaultIsZeroes for WrapI64 {}
|
impl DefaultIsZeroes for SecretI64 {}
|
||||||
impl DebugSecret for WrapI64 {}
|
impl DebugSecret for SecretI64 {}
|
||||||
impl CloneableSecret for WrapI64 {}
|
impl CloneableSecret for SecretI64 {}
|
||||||
|
|
||||||
impl std::ops::Deref for WrapI64 {
|
impl std::ops::Deref for SecretI64 {
|
||||||
type Target = i64;
|
type Target = i64;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct LoggedUser(Secret<WrapI64>);
|
pub struct LoggedUser(Secret<SecretI64>);
|
||||||
|
|
||||||
impl From<Claim> for LoggedUser {
|
impl From<Claim> for LoggedUser {
|
||||||
fn from(c: Claim) -> Self {
|
fn from(c: Claim) -> Self {
|
||||||
Self(Secret::new(WrapI64(c.uid)))
|
Self(Secret::new(SecretI64(c.uid)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LoggedUser {
|
impl LoggedUser {
|
||||||
pub fn new(uid: i64) -> Self {
|
pub fn new(uid: i64) -> Self {
|
||||||
Self(Secret::new(WrapI64(uid)))
|
Self(Secret::new(SecretI64(uid)))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_token(server_key: &Secret<String>, token: &str) -> Result<Self, AuthError> {
|
pub fn from_token(server_key: &Secret<String>, token: &str) -> Result<Self, AuthError> {
|
||||||
let user: LoggedUser = Token::decode_token(server_key, token)?.into();
|
let user: LoggedUser = Token::decode_token(server_key, token)?.into();
|
||||||
Ok(user)
|
Ok(user)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn expose_secret(&self) -> &i64 {
|
pub fn expose_secret(&self) -> &i64 {
|
||||||
self.0.expose_secret()
|
self.0.expose_secret()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const HEADER_TOKEN: &str = "token";
|
pub const HEADER_TOKEN: &str = "token";
|
||||||
|
@ -253,68 +253,68 @@ pub const EXPIRED_DURATION_DAYS: i64 = 30;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct Claim {
|
pub struct Claim {
|
||||||
iss: String,
|
iss: String,
|
||||||
uid: i64,
|
uid: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Claim {
|
impl Claim {
|
||||||
pub fn with_user_id(uid: i64) -> Self {
|
pub fn with_user_id(uid: i64) -> Self {
|
||||||
Self { iss: domain(), uid }
|
Self { iss: domain(), uid }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||||
pub struct Token(pub String);
|
pub struct Token(pub String);
|
||||||
|
|
||||||
impl Zeroize for Token {
|
impl Zeroize for Token {
|
||||||
fn zeroize(&mut self) {
|
fn zeroize(&mut self) {
|
||||||
self.0.zeroize()
|
self.0.zeroize()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Token {
|
impl Token {
|
||||||
pub fn create_token(uid: i64, server_key: &Secret<String>) -> Result<Self, AuthError> {
|
pub fn create_token(uid: i64, server_key: &Secret<String>) -> Result<Self, AuthError> {
|
||||||
let claim = Claim::with_user_id(uid);
|
let claim = Claim::with_user_id(uid);
|
||||||
let token = create_token(
|
let token = create_token(
|
||||||
server_key.expose_secret().as_str(),
|
server_key.expose_secret().as_str(),
|
||||||
claim,
|
claim,
|
||||||
Duration::days(EXPIRED_DURATION_DAYS),
|
Duration::days(EXPIRED_DURATION_DAYS),
|
||||||
)
|
)
|
||||||
.map_err(|e| match e {
|
.map_err(|e| match e {
|
||||||
TokenError::Jwt(_) => AuthError::Unauthorized,
|
TokenError::Jwt(_) => AuthError::Unauthorized,
|
||||||
TokenError::Expired => AuthError::Unauthorized,
|
TokenError::Expired => AuthError::Unauthorized,
|
||||||
})?;
|
})?;
|
||||||
Ok(Self(token))
|
Ok(Self(token))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decode_token(server_key: &Secret<String>, token: &str) -> Result<Claim, AuthError> {
|
pub fn decode_token(server_key: &Secret<String>, token: &str) -> Result<Claim, AuthError> {
|
||||||
parse_token::<Claim>(server_key.expose_secret().as_str(), token)
|
parse_token::<Claim>(server_key.expose_secret().as_str(), token)
|
||||||
.map_err(|_| AuthError::Unauthorized)
|
.map_err(|_| AuthError::Unauthorized)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn logged_user_from_request(
|
pub fn logged_user_from_request(
|
||||||
request: &HttpRequest,
|
request: &HttpRequest,
|
||||||
server_key: &Secret<String>,
|
server_key: &Secret<String>,
|
||||||
) -> Result<LoggedUser, AuthError> {
|
) -> Result<LoggedUser, AuthError> {
|
||||||
match request.headers().get(HEADER_TOKEN) {
|
match request.headers().get(HEADER_TOKEN) {
|
||||||
None => Err(AuthError::Unauthorized),
|
None => Err(AuthError::Unauthorized),
|
||||||
Some(header) => match header.to_str() {
|
Some(header) => match header.to_str() {
|
||||||
Ok(token_str) => LoggedUser::from_token(server_key, token_str),
|
Ok(token_str) => LoggedUser::from_token(server_key, token_str),
|
||||||
Err(_) => Err(AuthError::Unauthorized),
|
Err(_) => Err(AuthError::Unauthorized),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn uid_from_request(
|
pub fn uid_from_request(
|
||||||
request: &HttpRequest,
|
request: &HttpRequest,
|
||||||
server_key: &Secret<String>,
|
server_key: &Secret<String>,
|
||||||
) -> Result<Secret<i64>, AuthError> {
|
) -> Result<Secret<i64>, AuthError> {
|
||||||
match request.headers().get(HEADER_TOKEN) {
|
match request.headers().get(HEADER_TOKEN) {
|
||||||
Some(header) => match header.to_str() {
|
Some(header) => match header.to_str() {
|
||||||
Ok(val) => Token::decode_token(server_key, val).map(|claim| Secret::new(claim.uid)),
|
Ok(val) => Token::decode_token(server_key, val).map(|claim| Secret::new(claim.uid)),
|
||||||
Err(_) => Err(AuthError::Unauthorized),
|
Err(_) => Err(AuthError::Unauthorized),
|
||||||
},
|
},
|
||||||
None => Err(AuthError::Unauthorized),
|
None => Err(AuthError::Unauthorized),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,2 @@
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
pub mod token_state;
|
pub mod token_state;
|
||||||
pub mod ws;
|
|
||||||
|
|
|
@ -9,30 +9,30 @@ use std::future::{ready, Ready};
|
||||||
pub struct SessionToken(Session);
|
pub struct SessionToken(Session);
|
||||||
|
|
||||||
impl SessionToken {
|
impl SessionToken {
|
||||||
const TOKEN_ID_KEY: &'static str = "token";
|
const TOKEN_ID_KEY: &'static str = "token";
|
||||||
|
|
||||||
pub fn renew(&self) {
|
pub fn renew(&self) {
|
||||||
self.0.renew();
|
self.0.renew();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn insert_token(&self, token: Secret<Token>) -> Result<(), SessionInsertError> {
|
pub fn insert_token(&self, token: Secret<Token>) -> Result<(), SessionInsertError> {
|
||||||
self.0.insert(Self::TOKEN_ID_KEY, token.expose_secret())
|
self.0.insert(Self::TOKEN_ID_KEY, token.expose_secret())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_token(&self) -> Result<Option<String>, SessionGetError> {
|
pub fn get_token(&self) -> Result<Option<String>, SessionGetError> {
|
||||||
self.0.get(Self::TOKEN_ID_KEY)
|
self.0.get(Self::TOKEN_ID_KEY)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn log_out(self) {
|
pub fn log_out(self) {
|
||||||
self.0.purge()
|
self.0.purge()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FromRequest for SessionToken {
|
impl FromRequest for SessionToken {
|
||||||
type Error = <Session as FromRequest>::Error;
|
type Error = <Session as FromRequest>::Error;
|
||||||
type Future = Ready<Result<SessionToken, Self::Error>>;
|
type Future = Ready<Result<SessionToken, Self::Error>>;
|
||||||
|
|
||||||
fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
|
fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
|
||||||
ready(Ok(SessionToken(req.get_session())))
|
ready(Ok(SessionToken(req.get_session())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,164 +0,0 @@
|
||||||
use crate::component::auth::LoggedUser;
|
|
||||||
use crate::component::ws::entities::{
|
|
||||||
Connect, Disconnect, MessageDetail, MessagePayload, Socket, WebSocketMessage,
|
|
||||||
};
|
|
||||||
use crate::component::ws::server::WSServer;
|
|
||||||
use crate::component::ws::{HEARTBEAT_INTERVAL, PING_TIMEOUT};
|
|
||||||
use actix::*;
|
|
||||||
use actix_http::ws::Message::*;
|
|
||||||
use actix_web::web::Data;
|
|
||||||
use actix_web_actors::ws;
|
|
||||||
use bytes::Bytes;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
pub trait MessageReceiver: Send + Sync {
|
|
||||||
fn receive(&self, data: WSClientData);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
pub struct MessageReceivers {
|
|
||||||
inner: HashMap<u8, Arc<dyn MessageReceiver>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MessageReceivers {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
MessageReceivers::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn insert(&mut self, channel: u8, receiver: Arc<dyn MessageReceiver>) {
|
|
||||||
self.inner.insert(channel, receiver);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get(&self, source: u8) -> Option<&Arc<dyn MessageReceiver>> {
|
|
||||||
self.inner.get(&source)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub struct WSClientData {
|
|
||||||
pub(crate) socket: Socket,
|
|
||||||
pub(crate) detail: MessageDetail,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct WSClient {
|
|
||||||
user: Arc<LoggedUser>,
|
|
||||||
server: Addr<WSServer>,
|
|
||||||
msg_receivers: Data<MessageReceivers>,
|
|
||||||
hb: Instant,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WSClient {
|
|
||||||
pub fn new(
|
|
||||||
user: LoggedUser,
|
|
||||||
server: Addr<WSServer>,
|
|
||||||
msg_receivers: Data<MessageReceivers>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
user: Arc::new(user),
|
|
||||||
server,
|
|
||||||
msg_receivers,
|
|
||||||
hb: Instant::now(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
|
|
||||||
ctx.run_interval(HEARTBEAT_INTERVAL, |client, ctx| {
|
|
||||||
if Instant::now().duration_since(client.hb) > PING_TIMEOUT {
|
|
||||||
client.server.do_send(Disconnect {
|
|
||||||
user: client.user.clone(),
|
|
||||||
});
|
|
||||||
ctx.stop();
|
|
||||||
} else {
|
|
||||||
ctx.ping(b"");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_binary_message(&self, bytes: Bytes, socket: Socket) {
|
|
||||||
let MessagePayload { channel, detail } = MessagePayload::from_bytes(&bytes);
|
|
||||||
match self.msg_receivers.get(channel) {
|
|
||||||
None => {
|
|
||||||
tracing::error!("Can't find the receiver for {:?}", channel);
|
|
||||||
}
|
|
||||||
Some(handler) => {
|
|
||||||
let client_data = WSClientData { socket, detail };
|
|
||||||
handler.receive(client_data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WSClient {
|
|
||||||
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
|
|
||||||
match msg {
|
|
||||||
Ok(Ping(msg)) => {
|
|
||||||
self.hb = Instant::now();
|
|
||||||
ctx.pong(&msg);
|
|
||||||
}
|
|
||||||
Ok(Pong(_msg)) => {
|
|
||||||
// tracing::trace!("Receive {} pong {:?}", &self.session_id, &msg);
|
|
||||||
self.hb = Instant::now();
|
|
||||||
}
|
|
||||||
Ok(Binary(bytes)) => {
|
|
||||||
let socket = ctx.address().recipient();
|
|
||||||
self.handle_binary_message(bytes, socket);
|
|
||||||
}
|
|
||||||
Ok(Text(_)) => {
|
|
||||||
tracing::warn!("Receive unexpected text message");
|
|
||||||
}
|
|
||||||
Ok(Close(reason)) => {
|
|
||||||
ctx.close(reason);
|
|
||||||
ctx.stop();
|
|
||||||
}
|
|
||||||
Ok(ws::Message::Continuation(_)) => {}
|
|
||||||
Ok(ws::Message::Nop) => {}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("WebSocketStream protocol error {:?}", e);
|
|
||||||
ctx.stop();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Handler<WebSocketMessage> for WSClient {
|
|
||||||
type Result = ();
|
|
||||||
|
|
||||||
fn handle(&mut self, msg: WebSocketMessage, ctx: &mut Self::Context) {
|
|
||||||
ctx.binary(msg.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Actor for WSClient {
|
|
||||||
type Context = ws::WebsocketContext<Self>;
|
|
||||||
|
|
||||||
fn started(&mut self, ctx: &mut Self::Context) {
|
|
||||||
self.hb(ctx);
|
|
||||||
let socket = ctx.address().recipient();
|
|
||||||
let connect = Connect {
|
|
||||||
socket,
|
|
||||||
user: self.user.clone(),
|
|
||||||
};
|
|
||||||
self.server
|
|
||||||
.send(connect)
|
|
||||||
.into_actor(self)
|
|
||||||
.then(|res, _client, _ctx| {
|
|
||||||
match res {
|
|
||||||
Ok(Ok(_)) => tracing::trace!("Send connect message to server success"),
|
|
||||||
Ok(Err(e)) => tracing::error!("Send connect message to server failed: {:?}", e),
|
|
||||||
Err(e) => tracing::error!("Send connect message to server failed: {:?}", e),
|
|
||||||
}
|
|
||||||
fut::ready(())
|
|
||||||
})
|
|
||||||
.wait(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn stopping(&mut self, _: &mut Self::Context) -> Running {
|
|
||||||
self.server.do_send(Disconnect {
|
|
||||||
user: self.user.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
Running::Stop
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,92 +0,0 @@
|
||||||
use crate::component::auth::LoggedUser;
|
|
||||||
use actix::{Message, Recipient};
|
|
||||||
use bytes::Bytes;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::fmt::Formatter;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
pub type Socket = Recipient<WebSocketMessage>;
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)]
|
|
||||||
pub struct WSSessionId(pub String);
|
|
||||||
|
|
||||||
impl<T: AsRef<str>> std::convert::From<T> for WSSessionId {
|
|
||||||
fn from(s: T) -> Self {
|
|
||||||
WSSessionId(s.as_ref().to_owned())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for WSSessionId {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
let desc = &self.0.to_string();
|
|
||||||
f.write_str(desc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Session {
|
|
||||||
pub user: Arc<LoggedUser>,
|
|
||||||
pub socket: Socket,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::convert::From<Connect> for Session {
|
|
||||||
fn from(c: Connect) -> Self {
|
|
||||||
Self {
|
|
||||||
user: c.user,
|
|
||||||
socket: c.socket,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Message, Clone)]
|
|
||||||
#[rtype(result = "Result<(), WSError>")]
|
|
||||||
pub struct Connect {
|
|
||||||
pub socket: Socket,
|
|
||||||
pub user: Arc<LoggedUser>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Message, Clone)]
|
|
||||||
#[rtype(result = "Result<(), WSError>")]
|
|
||||||
pub struct Disconnect {
|
|
||||||
pub user: Arc<LoggedUser>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Message, Clone)]
|
|
||||||
#[rtype(result = "()")]
|
|
||||||
pub struct WebSocketMessage(pub Bytes);
|
|
||||||
|
|
||||||
impl std::ops::Deref for WebSocketMessage {
|
|
||||||
type Target = Bytes;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct MessagePayload {
|
|
||||||
pub(crate) channel: u8,
|
|
||||||
pub(crate) detail: MessageDetail,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MessagePayload {
|
|
||||||
pub fn from_bytes<T: AsRef<[u8]>>(bytes: T) -> Self {
|
|
||||||
serde_json::from_slice(bytes.as_ref()).unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
pub enum MessageDetail {
|
|
||||||
Document(MessageContent),
|
|
||||||
Database(MessageContent),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct MessageContent {
|
|
||||||
content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum WSError {
|
|
||||||
Internal,
|
|
||||||
}
|
|
|
@ -1,11 +0,0 @@
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
mod client;
|
|
||||||
mod entities;
|
|
||||||
mod server;
|
|
||||||
|
|
||||||
pub use client::*;
|
|
||||||
pub use server::WSServer;
|
|
||||||
|
|
||||||
pub(crate) const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(8);
|
|
||||||
pub(crate) const PING_TIMEOUT: Duration = Duration::from_secs(60);
|
|
|
@ -1,44 +0,0 @@
|
||||||
use crate::component::ws::entities::{Connect, Disconnect, WSError, WebSocketMessage};
|
|
||||||
use actix::{Actor, Context, Handler};
|
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
pub struct WSServer {}
|
|
||||||
|
|
||||||
impl WSServer {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
WSServer::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send(&self, _msg: WebSocketMessage) {}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Actor for WSServer {
|
|
||||||
type Context = Context<Self>;
|
|
||||||
fn started(&mut self, _ctx: &mut Self::Context) {}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Handler<Connect> for WSServer {
|
|
||||||
type Result = Result<(), WSError>;
|
|
||||||
fn handle(&mut self, _msg: Connect, _ctx: &mut Context<Self>) -> Self::Result {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Handler<Disconnect> for WSServer {
|
|
||||||
type Result = Result<(), WSError>;
|
|
||||||
fn handle(&mut self, _msg: Disconnect, _: &mut Context<Self>) -> Self::Result {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Handler<WebSocketMessage> for WSServer {
|
|
||||||
type Result = ();
|
|
||||||
|
|
||||||
fn handle(&mut self, _msg: WebSocketMessage, _ctx: &mut Context<Self>) -> Self::Result {}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl actix::Supervised for WSServer {
|
|
||||||
fn restarting(&mut self, _ctx: &mut Context<WSServer>) {
|
|
||||||
tracing::warn!("restarting");
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -3,12 +3,13 @@ use secrecy::Secret;
|
||||||
use serde_aux::field_attributes::deserialize_number_from_string;
|
use serde_aux::field_attributes::deserialize_number_from_string;
|
||||||
use sqlx::postgres::{PgConnectOptions, PgSslMode};
|
use sqlx::postgres::{PgConnectOptions, PgSslMode};
|
||||||
use std::convert::{TryFrom, TryInto};
|
use std::convert::{TryFrom, TryInto};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
#[derive(serde::Deserialize, Clone, Debug)]
|
#[derive(serde::Deserialize, Clone, Debug)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub database: DatabaseSetting,
|
pub database: DatabaseSetting,
|
||||||
pub application: ApplicationSettings,
|
pub application: ApplicationSettings,
|
||||||
pub redis_uri: Secret<String>,
|
pub redis_uri: Secret<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// We are using 127.0.0.1 as our host in address, we are instructing our
|
// We are using 127.0.0.1 as our host in address, we are instructing our
|
||||||
|
@ -21,73 +22,78 @@ pub struct Config {
|
||||||
//
|
//
|
||||||
#[derive(serde::Deserialize, Clone, Debug)]
|
#[derive(serde::Deserialize, Clone, Debug)]
|
||||||
pub struct ApplicationSettings {
|
pub struct ApplicationSettings {
|
||||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
pub host: String,
|
pub host: String,
|
||||||
pub server_key: Secret<String>,
|
pub data_dir: PathBuf,
|
||||||
pub tls_config: Option<TlsConfig>,
|
pub server_key: Secret<String>,
|
||||||
|
pub tls_config: Option<TlsConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ApplicationSettings {
|
impl ApplicationSettings {
|
||||||
pub fn use_https(&self) -> bool {
|
pub fn use_https(&self) -> bool {
|
||||||
match &self.tls_config {
|
match &self.tls_config {
|
||||||
None => false,
|
None => false,
|
||||||
Some(config) => match config {
|
Some(config) => match config {
|
||||||
TlsConfig::NoTls => false,
|
TlsConfig::NoTls => false,
|
||||||
TlsConfig::SelfSigned => true,
|
TlsConfig::SelfSigned => true,
|
||||||
},
|
},
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rocksdb_db_dir(&self) -> PathBuf {
|
||||||
|
self.data_dir.join("rocksdb")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Deserialize, Clone, Debug)]
|
#[derive(serde::Deserialize, Clone, Debug)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum TlsConfig {
|
pub enum TlsConfig {
|
||||||
NoTls,
|
NoTls,
|
||||||
SelfSigned,
|
SelfSigned,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Deserialize, Clone, Debug)]
|
#[derive(serde::Deserialize, Clone, Debug)]
|
||||||
pub struct DatabaseSetting {
|
pub struct DatabaseSetting {
|
||||||
pub username: String,
|
pub username: String,
|
||||||
pub password: String,
|
pub password: String,
|
||||||
#[serde(deserialize_with = "deserialize_number_from_string")]
|
#[serde(deserialize_with = "deserialize_number_from_string")]
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
pub host: String,
|
pub host: String,
|
||||||
pub database_name: String,
|
pub database_name: String,
|
||||||
pub require_ssl: bool,
|
pub require_ssl: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DatabaseSetting {
|
impl DatabaseSetting {
|
||||||
pub fn without_db(&self) -> PgConnectOptions {
|
pub fn without_db(&self) -> PgConnectOptions {
|
||||||
let ssl_mode = if self.require_ssl {
|
let ssl_mode = if self.require_ssl {
|
||||||
PgSslMode::Require
|
PgSslMode::Require
|
||||||
} else {
|
} else {
|
||||||
PgSslMode::Prefer
|
PgSslMode::Prefer
|
||||||
};
|
};
|
||||||
PgConnectOptions::new()
|
PgConnectOptions::new()
|
||||||
.host(&self.host)
|
.host(&self.host)
|
||||||
.username(&self.username)
|
.username(&self.username)
|
||||||
.password(&self.password)
|
.password(&self.password)
|
||||||
.port(self.port)
|
.port(self.port)
|
||||||
.ssl_mode(ssl_mode)
|
.ssl_mode(ssl_mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_db(&self) -> PgConnectOptions {
|
pub fn with_db(&self) -> PgConnectOptions {
|
||||||
self.without_db().database(&self.database_name)
|
self.without_db().database(&self.database_name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_configuration() -> Result<Config, config::ConfigError> {
|
pub fn get_configuration() -> Result<Config, config::ConfigError> {
|
||||||
let base_path = std::env::current_dir().expect("Failed to determine the current directory");
|
let base_path = std::env::current_dir().expect("Failed to determine the current directory");
|
||||||
let configuration_dir = base_path.join("configuration");
|
let configuration_dir = base_path.join("configuration");
|
||||||
|
|
||||||
let environment: Environment = std::env::var("APP_ENVIRONMENT")
|
let environment: Environment = std::env::var("APP_ENVIRONMENT")
|
||||||
.unwrap_or_else(|_| "local".into())
|
.unwrap_or_else(|_| "local".into())
|
||||||
.try_into()
|
.try_into()
|
||||||
.expect("Failed to parse APP_ENVIRONMENT.");
|
.expect("Failed to parse APP_ENVIRONMENT.");
|
||||||
|
|
||||||
let builder = InnerConfig::builder()
|
let builder = InnerConfig::builder()
|
||||||
.set_default("default", "1")?
|
.set_default("default", "1")?
|
||||||
.add_source(
|
.add_source(
|
||||||
config::File::from(configuration_dir.join("base"))
|
config::File::from(configuration_dir.join("base"))
|
||||||
|
@ -104,36 +110,36 @@ pub fn get_configuration() -> Result<Config, config::ConfigError> {
|
||||||
// `Settings.application.port`
|
// `Settings.application.port`
|
||||||
.add_source(config::Environment::with_prefix("app").separator("__"));
|
.add_source(config::Environment::with_prefix("app").separator("__"));
|
||||||
|
|
||||||
let config = builder.build()?;
|
let config = builder.build()?;
|
||||||
config.try_deserialize()
|
config.try_deserialize()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The possible runtime environment for our application.
|
/// The possible runtime environment for our application.
|
||||||
pub enum Environment {
|
pub enum Environment {
|
||||||
Local,
|
Local,
|
||||||
Production,
|
Production,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Environment {
|
impl Environment {
|
||||||
pub fn as_str(&self) -> &'static str {
|
pub fn as_str(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Environment::Local => "local",
|
Environment::Local => "local",
|
||||||
Environment::Production => "production",
|
Environment::Production => "production",
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<String> for Environment {
|
impl TryFrom<String> for Environment {
|
||||||
type Error = String;
|
type Error = String;
|
||||||
|
|
||||||
fn try_from(s: String) -> Result<Self, Self::Error> {
|
fn try_from(s: String) -> Result<Self, Self::Error> {
|
||||||
match s.to_lowercase().as_str() {
|
match s.to_lowercase().as_str() {
|
||||||
"local" => Ok(Self::Local),
|
"local" => Ok(Self::Local),
|
||||||
"production" => Ok(Self::Production),
|
"production" => Ok(Self::Production),
|
||||||
other => Err(format!(
|
other => Err(format!(
|
||||||
"{} is not a supported environment. Use either `local` or `production`.",
|
"{} is not a supported environment. Use either `local` or `production`.",
|
||||||
other
|
other
|
||||||
)),
|
)),
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
use std::env;
|
use std::env;
|
||||||
|
|
||||||
pub fn domain() -> String {
|
pub fn domain() -> String {
|
||||||
env::var("DOMAIN").unwrap_or_else(|_| "localhost".to_string())
|
env::var("DOMAIN").unwrap_or_else(|_| "localhost".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn jwt_secret() -> String {
|
pub fn jwt_secret() -> String {
|
||||||
env::var("JWT_SECRET").unwrap_or_else(|_| "my secret".into())
|
env::var("JWT_SECRET").unwrap_or_else(|_| "my secret".into())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn secret() -> String {
|
pub fn secret() -> String {
|
||||||
env::var("SECRET_KEY").unwrap_or_else(|_| "0123".repeat(8))
|
env::var("SECRET_KEY").unwrap_or_else(|_| "0123".repeat(8))
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,44 +4,44 @@ use validator::validate_email;
|
||||||
pub struct UserEmail(pub String);
|
pub struct UserEmail(pub String);
|
||||||
|
|
||||||
impl UserEmail {
|
impl UserEmail {
|
||||||
pub fn parse(s: String) -> Result<UserEmail, String> {
|
pub fn parse(s: String) -> Result<UserEmail, String> {
|
||||||
if s.trim().is_empty() {
|
if s.trim().is_empty() {
|
||||||
return Err("Email can not be empty or whitespace".to_string());
|
return Err("Email can not be empty or whitespace".to_string());
|
||||||
}
|
|
||||||
|
|
||||||
if validate_email(&s) {
|
|
||||||
Ok(Self(s))
|
|
||||||
} else {
|
|
||||||
Err("Invalid email".to_string())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if validate_email(&s) {
|
||||||
|
Ok(Self(s))
|
||||||
|
} else {
|
||||||
|
Err("Invalid email".to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsRef<str> for UserEmail {
|
impl AsRef<str> for UserEmail {
|
||||||
fn as_ref(&self) -> &str {
|
fn as_ref(&self) -> &str {
|
||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn empty_string_is_rejected() {
|
fn empty_string_is_rejected() {
|
||||||
let email = "".to_string();
|
let email = "".to_string();
|
||||||
assert!(UserEmail::parse(email).is_err());
|
assert!(UserEmail::parse(email).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn email_missing_at_symbol_is_rejected() {
|
fn email_missing_at_symbol_is_rejected() {
|
||||||
let email = "helloworld.com".to_string();
|
let email = "helloworld.com".to_string();
|
||||||
assert!(UserEmail::parse(email).is_err());
|
assert!(UserEmail::parse(email).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn email_missing_subject_is_rejected() {
|
fn email_missing_subject_is_rejected() {
|
||||||
let email = "@domain.com".to_string();
|
let email = "@domain.com".to_string();
|
||||||
assert!(UserEmail::parse(email).is_err());
|
assert!(UserEmail::parse(email).is_err());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,78 +4,78 @@ use unicode_segmentation::UnicodeSegmentation;
|
||||||
pub struct UserName(pub String);
|
pub struct UserName(pub String);
|
||||||
|
|
||||||
impl UserName {
|
impl UserName {
|
||||||
pub fn parse(s: String) -> Result<UserName, String> {
|
pub fn parse(s: String) -> Result<UserName, String> {
|
||||||
let is_empty_or_whitespace = s.trim().is_empty();
|
let is_empty_or_whitespace = s.trim().is_empty();
|
||||||
if is_empty_or_whitespace {
|
if is_empty_or_whitespace {
|
||||||
return Err("User name can not be empty or whitespace".to_string());
|
return Err("User name can not be empty or whitespace".to_string());
|
||||||
}
|
|
||||||
// A grapheme is defined by the Unicode standard as a "user-perceived"
|
|
||||||
// character: `å` is a single grapheme, but it is composed of two characters
|
|
||||||
// (`a` and `̊`).
|
|
||||||
//
|
|
||||||
// `graphemes` returns an iterator over the graphemes in the input `s`.
|
|
||||||
// `true` specifies that we want to use the extended grapheme definition set,
|
|
||||||
// the recommended one.
|
|
||||||
let is_too_long = s.graphemes(true).count() > 256;
|
|
||||||
if is_too_long {
|
|
||||||
return Err("User name is too long".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
|
|
||||||
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
|
|
||||||
|
|
||||||
if contains_forbidden_characters {
|
|
||||||
return Err("User name contains invalid characters".to_string());
|
|
||||||
}
|
|
||||||
Ok(Self(s))
|
|
||||||
}
|
}
|
||||||
|
// A grapheme is defined by the Unicode standard as a "user-perceived"
|
||||||
|
// character: `å` is a single grapheme, but it is composed of two characters
|
||||||
|
// (`a` and `̊`).
|
||||||
|
//
|
||||||
|
// `graphemes` returns an iterator over the graphemes in the input `s`.
|
||||||
|
// `true` specifies that we want to use the extended grapheme definition set,
|
||||||
|
// the recommended one.
|
||||||
|
let is_too_long = s.graphemes(true).count() > 256;
|
||||||
|
if is_too_long {
|
||||||
|
return Err("User name is too long".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
|
||||||
|
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
|
||||||
|
|
||||||
|
if contains_forbidden_characters {
|
||||||
|
return Err("User name contains invalid characters".to_string());
|
||||||
|
}
|
||||||
|
Ok(Self(s))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsRef<str> for UserName {
|
impl AsRef<str> for UserName {
|
||||||
fn as_ref(&self) -> &str {
|
fn as_ref(&self) -> &str {
|
||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::UserName;
|
use super::UserName;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn a_256_grapheme_long_name_is_valid() {
|
fn a_256_grapheme_long_name_is_valid() {
|
||||||
let name = "a̐".repeat(256);
|
let name = "a̐".repeat(256);
|
||||||
assert!(UserName::parse(name).is_ok());
|
assert!(UserName::parse(name).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn a_name_longer_than_256_graphemes_is_rejected() {
|
fn a_name_longer_than_256_graphemes_is_rejected() {
|
||||||
let name = "a".repeat(257);
|
let name = "a".repeat(257);
|
||||||
assert!(UserName::parse(name).is_err());
|
assert!(UserName::parse(name).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn whitespace_only_names_are_rejected() {
|
fn whitespace_only_names_are_rejected() {
|
||||||
let name = " ".to_string();
|
let name = " ".to_string();
|
||||||
assert!(UserName::parse(name).is_err());
|
assert!(UserName::parse(name).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn empty_string_is_rejected() {
|
fn empty_string_is_rejected() {
|
||||||
let name = "".to_string();
|
let name = "".to_string();
|
||||||
assert!(UserName::parse(name).is_err());
|
assert!(UserName::parse(name).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn names_containing_an_invalid_character_are_rejected() {
|
fn names_containing_an_invalid_character_are_rejected() {
|
||||||
for name in &['/', '(', ')', '"', '<', '>', '\\', '{', '}'] {
|
for name in &['/', '(', ')', '"', '<', '>', '\\', '{', '}'] {
|
||||||
let name = name.to_string();
|
let name = name.to_string();
|
||||||
assert!(UserName::parse(name).is_err());
|
assert!(UserName::parse(name).is_err());
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn a_valid_name_is_parsed_successfully() {
|
fn a_valid_name_is_parsed_successfully() {
|
||||||
let name = "nathan".to_string();
|
let name = "nathan".to_string();
|
||||||
assert!(UserName::parse(name).is_ok());
|
assert!(UserName::parse(name).is_ok());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,33 +6,33 @@ use unicode_segmentation::UnicodeSegmentation;
|
||||||
pub struct UserPassword(pub String);
|
pub struct UserPassword(pub String);
|
||||||
|
|
||||||
impl UserPassword {
|
impl UserPassword {
|
||||||
pub fn parse(s: String) -> Result<UserPassword, String> {
|
pub fn parse(s: String) -> Result<UserPassword, String> {
|
||||||
if s.trim().is_empty() {
|
if s.trim().is_empty() {
|
||||||
return Err("User password can not be empty or whitespace".to_owned());
|
return Err("User password can not be empty or whitespace".to_owned());
|
||||||
}
|
|
||||||
|
|
||||||
if s.graphemes(true).count() > 100 {
|
|
||||||
return Err("Password is too long".to_owned());
|
|
||||||
}
|
|
||||||
|
|
||||||
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
|
|
||||||
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
|
|
||||||
if contains_forbidden_characters {
|
|
||||||
return Err("Password contains invalid characters".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
if !validate_password(&s) {
|
|
||||||
return Err("Password format invalid".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Self(s))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.graphemes(true).count() > 100 {
|
||||||
|
return Err("Password is too long".to_owned());
|
||||||
|
}
|
||||||
|
|
||||||
|
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
|
||||||
|
let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g));
|
||||||
|
if contains_forbidden_characters {
|
||||||
|
return Err("Password contains invalid characters".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !validate_password(&s) {
|
||||||
|
return Err("Password format invalid".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self(s))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsRef<str> for UserPassword {
|
impl AsRef<str> for UserPassword {
|
||||||
fn as_ref(&self) -> &str {
|
fn as_ref(&self) -> &str {
|
||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
|
@ -53,11 +53,11 @@ lazy_static! {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn validate_password(password: &str) -> bool {
|
pub fn validate_password(password: &str) -> bool {
|
||||||
match PASSWORD.is_match(password) {
|
match PASSWORD.is_match(password) {
|
||||||
Ok(is_match) => is_match,
|
Ok(is_match) => is_match,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("validate_password fail: {:?}", e);
|
tracing::error!("validate_password fail: {:?}", e);
|
||||||
false
|
false
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
18
src/main.rs
18
src/main.rs
|
@ -4,13 +4,17 @@ use appflowy_server::telemetry::{get_subscriber, init_subscriber};
|
||||||
|
|
||||||
#[actix_web::main]
|
#[actix_web::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let subscriber = get_subscriber("appflowy_server".into(), "info".into(), std::io::stdout);
|
let subscriber = get_subscriber(
|
||||||
init_subscriber(subscriber);
|
"appflowy_server".to_string(),
|
||||||
|
"info".to_string(),
|
||||||
|
std::io::stdout,
|
||||||
|
);
|
||||||
|
init_subscriber(subscriber);
|
||||||
|
|
||||||
let configuration = get_configuration().expect("Failed to read configuration.");
|
let configuration = get_configuration().expect("Failed to read configuration.");
|
||||||
let state = init_state(&configuration).await;
|
let state = init_state(&configuration).await;
|
||||||
let application = Application::build(configuration, state).await?;
|
let application = Application::build(configuration, state).await?;
|
||||||
application.run_until_stopped().await?;
|
application.run_until_stopped().await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,11 +6,11 @@ use actix_web::http;
|
||||||
// http://www.ruanyifeng.com/blog/2016/04/cors.html
|
// http://www.ruanyifeng.com/blog/2016/04/cors.html
|
||||||
// Cors short for Cross-Origin Resource Sharing.
|
// Cors short for Cross-Origin Resource Sharing.
|
||||||
pub fn default_cors() -> Cors {
|
pub fn default_cors() -> Cors {
|
||||||
Cors::default() // allowed_origin return access-control-allow-origin: * by default
|
Cors::default() // allowed_origin return access-control-allow-origin: * by default
|
||||||
// .allowed_origin("http://127.0.0.1:8080")
|
.allow_any_origin()
|
||||||
.send_wildcard()
|
.send_wildcard()
|
||||||
.allowed_methods(vec!["GET", "POST", "PUT", "DELETE"])
|
.allowed_methods(vec!["GET", "POST", "PUT", "DELETE"])
|
||||||
.allowed_headers(vec![http::header::ACCEPT])
|
.allowed_headers(vec![http::header::ACCEPT])
|
||||||
.allowed_header(http::header::CONTENT_TYPE)
|
.allowed_header(http::header::CONTENT_TYPE)
|
||||||
.max_age(3600)
|
.max_age(3600)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,26 +5,26 @@ pub const CA_CRT: &str = include_str!("../cert/cert.pem");
|
||||||
pub const CA_KEY: &str = include_str!("../cert/key.pem");
|
pub const CA_KEY: &str = include_str!("../cert/key.pem");
|
||||||
|
|
||||||
pub fn create_self_signed_certificate() -> Result<(Secret<String>, Secret<String>), RcgenError> {
|
pub fn create_self_signed_certificate() -> Result<(Secret<String>, Secret<String>), RcgenError> {
|
||||||
let key = KeyPair::from_pem(CA_KEY)?;
|
let key = KeyPair::from_pem(CA_KEY)?;
|
||||||
let params = CertificateParams::from_ca_cert_pem(CA_CRT, key)?;
|
let params = CertificateParams::from_ca_cert_pem(CA_CRT, key)?;
|
||||||
let ca_cert = Certificate::from_params(params)?;
|
let ca_cert = Certificate::from_params(params)?;
|
||||||
|
|
||||||
let mut params = CertificateParams::default();
|
let mut params = CertificateParams::default();
|
||||||
params
|
params
|
||||||
.subject_alt_names
|
.subject_alt_names
|
||||||
.push(SanType::IpAddress("127.0.0.1".parse().unwrap()));
|
.push(SanType::IpAddress("127.0.0.1".parse().unwrap()));
|
||||||
params
|
params
|
||||||
.subject_alt_names
|
.subject_alt_names
|
||||||
.push(SanType::IpAddress("0.0.0.0".parse().unwrap()));
|
.push(SanType::IpAddress("0.0.0.0".parse().unwrap()));
|
||||||
params
|
params
|
||||||
.subject_alt_names
|
.subject_alt_names
|
||||||
.push(SanType::DnsName("localhost".to_string()));
|
.push(SanType::DnsName("localhost".to_string()));
|
||||||
|
|
||||||
// Generate a certificate that's valid for:
|
// Generate a certificate that's valid for:
|
||||||
// 1. localhost
|
// 1. localhost
|
||||||
// 2. 127.0.0.1
|
// 2. 127.0.0.1
|
||||||
let gen_cert = Certificate::from_params(params)?;
|
let gen_cert = Certificate::from_params(params)?;
|
||||||
let server_crt = Secret::new(gen_cert.serialize_pem_with_signer(&ca_cert)?);
|
let server_crt = Secret::new(gen_cert.serialize_pem_with_signer(&ca_cert)?);
|
||||||
let server_key = Secret::new(gen_cert.serialize_private_key_pem());
|
let server_key = Secret::new(gen_cert.serialize_private_key_pem());
|
||||||
Ok((server_crt, server_key))
|
Ok((server_crt, server_key))
|
||||||
}
|
}
|
||||||
|
|
94
src/state.rs
94
src/state.rs
|
@ -1,6 +1,8 @@
|
||||||
use crate::component::auth::LoggedUser;
|
use crate::component::auth::LoggedUser;
|
||||||
use crate::config::config::Config;
|
use crate::config::config::Config;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
|
use collab_persistence::kv::rocks_kv::RocksCollabDB;
|
||||||
|
|
||||||
use snowflake::Snowflake;
|
use snowflake::Snowflake;
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
|
@ -9,70 +11,72 @@ use tokio::sync::RwLock;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct State {
|
pub struct State {
|
||||||
pub pg_pool: PgPool,
|
pub pg_pool: PgPool,
|
||||||
pub config: Arc<Config>,
|
pub rocksdb: Arc<RocksCollabDB>,
|
||||||
pub user: Arc<RwLock<UserCache>>,
|
pub config: Arc<Config>,
|
||||||
pub id_gen: Arc<RwLock<Snowflake>>,
|
pub user: Arc<RwLock<UserCache>>,
|
||||||
|
pub id_gen: Arc<RwLock<Snowflake>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl State {
|
impl State {
|
||||||
pub async fn load_users(_pool: &PgPool) {
|
pub async fn load_users(_pool: &PgPool) {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn next_user_id(&self) -> i64 {
|
pub async fn next_user_id(&self) -> i64 {
|
||||||
self.id_gen.write().await.next_id()
|
self.id_gen.write().await.next_id()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy)]
|
#[derive(Clone, Debug, Copy)]
|
||||||
enum AuthStatus {
|
enum AuthStatus {
|
||||||
Authorized(DateTime<Utc>),
|
Authorized(DateTime<Utc>),
|
||||||
NotAuthorized,
|
NotAuthorized,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const EXPIRED_DURATION_DAYS: i64 = 30;
|
pub const EXPIRED_DURATION_DAYS: i64 = 30;
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
pub struct UserCache {
|
pub struct UserCache {
|
||||||
// Keep track the user authentication state
|
// Keep track the user authentication state
|
||||||
user: BTreeMap<i64, AuthStatus>,
|
user: BTreeMap<i64, AuthStatus>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UserCache {
|
impl UserCache {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
UserCache::default()
|
UserCache::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_authorized(&self, user: &LoggedUser) -> bool {
|
pub fn is_authorized(&self, user: &LoggedUser) -> bool {
|
||||||
match self.user.get(user.expose_secret()) {
|
match self.user.get(user.expose_secret()) {
|
||||||
None => {
|
None => {
|
||||||
tracing::debug!("user not login yet or server was reboot");
|
tracing::debug!("user not login yet or server was reboot");
|
||||||
false
|
false
|
||||||
}
|
},
|
||||||
Some(status) => match *status {
|
Some(status) => match *status {
|
||||||
AuthStatus::Authorized(last_time) => {
|
AuthStatus::Authorized(last_time) => {
|
||||||
let current_time = Utc::now();
|
let current_time = Utc::now();
|
||||||
let days = (current_time - last_time).num_days();
|
let days = (current_time - last_time).num_days();
|
||||||
days < EXPIRED_DURATION_DAYS
|
days < EXPIRED_DURATION_DAYS
|
||||||
}
|
},
|
||||||
AuthStatus::NotAuthorized => {
|
AuthStatus::NotAuthorized => {
|
||||||
tracing::debug!("user logout already");
|
tracing::debug!("user logout already");
|
||||||
false
|
false
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn authorized(&mut self, user: LoggedUser) {
|
pub fn authorized(&mut self, user: LoggedUser) {
|
||||||
self.user.insert(
|
self.user.insert(
|
||||||
user.expose_secret().to_owned(),
|
user.expose_secret().to_owned(),
|
||||||
AuthStatus::Authorized(Utc::now()),
|
AuthStatus::Authorized(Utc::now()),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn unauthorized(&mut self, user: LoggedUser) {
|
pub fn unauthorized(&mut self, user: LoggedUser) {
|
||||||
self.user
|
self
|
||||||
.insert(user.expose_secret().to_owned(), AuthStatus::NotAuthorized);
|
.user
|
||||||
}
|
.insert(user.expose_secret().to_owned(), AuthStatus::NotAuthorized);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,39 +4,41 @@ use tracing::Subscriber;
|
||||||
use tracing_bunyan_formatter::{BunyanFormattingLayer, JsonStorageLayer};
|
use tracing_bunyan_formatter::{BunyanFormattingLayer, JsonStorageLayer};
|
||||||
use tracing_log::LogTracer;
|
use tracing_log::LogTracer;
|
||||||
use tracing_subscriber::fmt::MakeWriter;
|
use tracing_subscriber::fmt::MakeWriter;
|
||||||
use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry};
|
use tracing_subscriber::{layer::SubscriberExt, EnvFilter};
|
||||||
|
|
||||||
/// Compose multiple layers into a `tracing`'s subscriber.
|
/// Compose multiple layers into a `tracing`'s subscriber.
|
||||||
pub fn get_subscriber<Sink>(
|
pub fn get_subscriber<Sink>(
|
||||||
name: String,
|
name: String,
|
||||||
env_filter: String,
|
env_filter: String,
|
||||||
sink: Sink,
|
sink: Sink,
|
||||||
) -> impl Subscriber + Sync + Send
|
) -> impl Subscriber + Sync + Send
|
||||||
where
|
where
|
||||||
Sink: for<'a> MakeWriter<'a> + Send + Sync + 'static,
|
Sink: for<'a> MakeWriter<'a> + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
let env_filter =
|
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(env_filter));
|
||||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(env_filter));
|
// let env_filter = EnvFilter::new(env_filter);
|
||||||
let formatting_layer = BunyanFormattingLayer::new(name, sink);
|
let formatting_layer = BunyanFormattingLayer::new(name, sink);
|
||||||
Registry::default()
|
tracing_subscriber::fmt()
|
||||||
.with(env_filter)
|
.with_ansi(true)
|
||||||
.with(JsonStorageLayer)
|
.finish()
|
||||||
.with(formatting_layer)
|
.with(env_filter)
|
||||||
|
.with(JsonStorageLayer)
|
||||||
|
.with(formatting_layer)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register a subscriber as global default to process span data.
|
/// Register a subscriber as global default to process span data.
|
||||||
///
|
///
|
||||||
/// It should only be called once!
|
/// It should only be called once!
|
||||||
pub fn init_subscriber(subscriber: impl Subscriber + Sync + Send) {
|
pub fn init_subscriber(subscriber: impl Subscriber + Sync + Send) {
|
||||||
LogTracer::init().expect("Failed to set logger");
|
LogTracer::init().expect("Failed to set logger");
|
||||||
set_global_default(subscriber).expect("Failed to set subscriber");
|
set_global_default(subscriber).expect("Failed to set subscriber");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn spawn_blocking_with_tracing<F, R>(f: F) -> JoinHandle<R>
|
pub fn spawn_blocking_with_tracing<F, R>(f: F) -> JoinHandle<R>
|
||||||
where
|
where
|
||||||
F: FnOnce() -> R + Send + 'static,
|
F: FnOnce() -> R + Send + 'static,
|
||||||
R: Send + 'static,
|
R: Send + 'static,
|
||||||
{
|
{
|
||||||
let current_span = tracing::Span::current();
|
let current_span = tracing::Span::current();
|
||||||
actix_web::rt::task::spawn_blocking(move || current_span.in_scope(f))
|
actix_web::rt::task::spawn_blocking(move || current_span.in_scope(f))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,44 +1,44 @@
|
||||||
use crate::test_server::{spawn_server, TestUser};
|
use crate::util::{spawn_server, TestUser};
|
||||||
use actix_web::http::StatusCode;
|
use actix_web::http::StatusCode;
|
||||||
use appflowy_server::component::auth::LoginResponse;
|
use appflowy_server::component::auth::LoginResponse;
|
||||||
|
|
||||||
#[tokio::test]
|
#[actix_rt::test]
|
||||||
async fn login_success() {
|
async fn login_success() {
|
||||||
let server = spawn_server().await;
|
let server = spawn_server().await;
|
||||||
let test_user = TestUser::generate();
|
let test_user = TestUser::generate();
|
||||||
test_user.register(&server).await;
|
test_user.register(&server).await;
|
||||||
|
|
||||||
let http_resp = server.login(&test_user.email, &test_user.password).await;
|
let http_resp = server.login(&test_user.email, &test_user.password).await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::OK);
|
assert_eq!(http_resp.status(), StatusCode::OK);
|
||||||
|
|
||||||
let bytes = http_resp.bytes().await.unwrap();
|
let bytes = http_resp.bytes().await.unwrap();
|
||||||
let response: LoginResponse = serde_json::from_slice(&bytes).unwrap();
|
let response: LoginResponse = serde_json::from_slice(&bytes).unwrap();
|
||||||
assert!(!response.token.is_empty())
|
assert!(!response.token.is_empty())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[actix_rt::test]
|
||||||
async fn login_with_empty_email() {
|
async fn login_with_empty_email() {
|
||||||
let server = spawn_server().await;
|
let server = spawn_server().await;
|
||||||
let test_user = TestUser::generate();
|
let test_user = TestUser::generate();
|
||||||
test_user.register(&server).await;
|
test_user.register(&server).await;
|
||||||
|
|
||||||
let http_resp = server.login("", &test_user.password).await;
|
let http_resp = server.login("", &test_user.password).await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[actix_rt::test]
|
||||||
async fn login_with_empty_password() {
|
async fn login_with_empty_password() {
|
||||||
let server = spawn_server().await;
|
let server = spawn_server().await;
|
||||||
let test_user = TestUser::generate();
|
let test_user = TestUser::generate();
|
||||||
test_user.register(&server).await;
|
test_user.register(&server).await;
|
||||||
|
|
||||||
let http_resp = server.login(&test_user.email, "").await;
|
let http_resp = server.login(&test_user.email, "").await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[actix_rt::test]
|
||||||
async fn login_with_unknown_user() {
|
async fn login_with_unknown_user() {
|
||||||
let server = spawn_server().await;
|
let server = spawn_server().await;
|
||||||
let http_resp = server.login("unknown@appflowy.io", "Abc@123!").await;
|
let http_resp = server.login("unknown@appflowy.io", "Abc@123!").await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED);
|
assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
mod login;
|
mod login;
|
||||||
mod password;
|
mod password;
|
||||||
mod register;
|
mod register;
|
||||||
mod test_server;
|
mod ws;
|
|
@ -1,53 +1,53 @@
|
||||||
use crate::test_server::{spawn_server, TestUser};
|
use crate::util::{spawn_server, TestUser};
|
||||||
use actix_web::http::StatusCode;
|
use actix_web::http::StatusCode;
|
||||||
|
|
||||||
#[tokio::test]
|
#[actix_rt::test]
|
||||||
async fn change_password_with_unmatched_password() {
|
async fn change_password_with_unmatched_password() {
|
||||||
let server = spawn_server().await;
|
let server = spawn_server().await;
|
||||||
let test_user = TestUser::generate();
|
let test_user = TestUser::generate();
|
||||||
let token = test_user.register(&server).await;
|
let token = test_user.register(&server).await;
|
||||||
|
|
||||||
let new_password = "HelloWorld@1a";
|
let new_password = "HelloWorld@1a";
|
||||||
let new_password_confirm = "HeloWorld@1a";
|
let new_password_confirm = "HeloWorld@1a";
|
||||||
let http_resp = server
|
let http_resp = server
|
||||||
.change_password(
|
.change_password(
|
||||||
token,
|
token,
|
||||||
&test_user.password,
|
&test_user.password,
|
||||||
new_password,
|
new_password,
|
||||||
new_password_confirm,
|
new_password_confirm,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[actix_rt::test]
|
||||||
async fn login_fail_after_change_password() {
|
async fn login_fail_after_change_password() {
|
||||||
let server = spawn_server().await;
|
let server = spawn_server().await;
|
||||||
let test_user = TestUser::generate();
|
let test_user = TestUser::generate();
|
||||||
let token = test_user.register(&server).await;
|
let token = test_user.register(&server).await;
|
||||||
|
|
||||||
let new_password = "HelloWorld@1a";
|
let new_password = "HelloWorld@1a";
|
||||||
let http_resp = server
|
let http_resp = server
|
||||||
.change_password(token, &test_user.password, new_password, new_password)
|
.change_password(token, &test_user.password, new_password, new_password)
|
||||||
.await;
|
.await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::OK);
|
assert_eq!(http_resp.status(), StatusCode::OK);
|
||||||
|
|
||||||
let http_resp = server.login(&test_user.email, &test_user.password).await;
|
let http_resp = server.login(&test_user.email, &test_user.password).await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED);
|
assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[actix_rt::test]
|
||||||
async fn login_success_with_new_password() {
|
async fn login_success_with_new_password() {
|
||||||
let server = spawn_server().await;
|
let server = spawn_server().await;
|
||||||
let test_user = TestUser::generate();
|
let test_user = TestUser::generate();
|
||||||
let token = test_user.register(&server).await;
|
let token = test_user.register(&server).await;
|
||||||
|
|
||||||
let new_password = "HelloWorld@1a";
|
let new_password = "HelloWorld@1a";
|
||||||
let http_resp = server
|
let http_resp = server
|
||||||
.change_password(token, &test_user.password, new_password, new_password)
|
.change_password(token, &test_user.password, new_password, new_password)
|
||||||
.await;
|
.await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::OK);
|
assert_eq!(http_resp.status(), StatusCode::OK);
|
||||||
|
|
||||||
let http_resp = server.login(&test_user.email, new_password).await;
|
let http_resp = server.login(&test_user.email, new_password).await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::OK);
|
assert_eq!(http_resp.status(), StatusCode::OK);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,53 +1,53 @@
|
||||||
use crate::test_server::{error_msg_from_resp, spawn_server};
|
use crate::util::{error_msg_from_resp, spawn_server};
|
||||||
use appflowy_server::component::auth::{InputParamsError, RegisterResponse};
|
use appflowy_server::component::auth::{InputParamsError, RegisterResponse};
|
||||||
use reqwest::StatusCode;
|
use reqwest::StatusCode;
|
||||||
|
|
||||||
#[tokio::test]
|
#[actix_rt::test]
|
||||||
// curl -X POST --url http://0.0.0.0:8000/api/user/register --header 'content-type: application/json' --data '{"name":"fake name", "email":"fake@appflowy.io", "password":"Fake@123"}'
|
// curl -X POST --url http://0.0.0.0:8000/api/user/register --header 'content-type: application/json' --data '{"name":"fake name", "email":"fake@appflowy.io", "password":"Fake@123"}'
|
||||||
async fn register_success() {
|
async fn register_success() {
|
||||||
let server = spawn_server().await;
|
let server = spawn_server().await;
|
||||||
let http_resp = server
|
let http_resp = server
|
||||||
.register("user 1", "fake@appflowy.io", "FakePassword!123")
|
.register("user 1", "fake@appflowy.io", "FakePassword!123")
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let bytes = http_resp.bytes().await.unwrap();
|
let bytes = http_resp.bytes().await.unwrap();
|
||||||
let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap();
|
let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap();
|
||||||
println!("{:?}", response);
|
println!("{:?}", response);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[actix_rt::test]
|
||||||
async fn register_with_invalid_password() {
|
async fn register_with_invalid_password() {
|
||||||
let server = spawn_server().await;
|
let server = spawn_server().await;
|
||||||
let http_resp = server.register("user 1", "fake@appflowy.io", "123").await;
|
let http_resp = server.register("user 1", "fake@appflowy.io", "123").await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
error_msg_from_resp(http_resp).await,
|
error_msg_from_resp(http_resp).await,
|
||||||
InputParamsError::InvalidPassword.to_string()
|
InputParamsError::InvalidPassword.to_string()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[actix_rt::test]
|
||||||
async fn register_with_invalid_name() {
|
async fn register_with_invalid_name() {
|
||||||
let server = spawn_server().await;
|
let server = spawn_server().await;
|
||||||
let name = "".to_string();
|
let name = "".to_string();
|
||||||
let http_resp = server
|
let http_resp = server
|
||||||
.register(&name, "fake@appflowy.io", "FakePassword!123")
|
.register(&name, "fake@appflowy.io", "FakePassword!123")
|
||||||
.await;
|
.await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
error_msg_from_resp(http_resp).await,
|
error_msg_from_resp(http_resp).await,
|
||||||
InputParamsError::InvalidName(name).to_string()
|
InputParamsError::InvalidName(name).to_string()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[actix_rt::test]
|
||||||
async fn register_with_invalid_email() {
|
async fn register_with_invalid_email() {
|
||||||
let server = spawn_server().await;
|
let server = spawn_server().await;
|
||||||
let email = "appflowy.io".to_string();
|
let email = "appflowy.io".to_string();
|
||||||
let http_resp = server.register("me", &email, "FakePassword!123").await;
|
let http_resp = server.register("me", &email, "FakePassword!123").await;
|
||||||
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
error_msg_from_resp(http_resp).await,
|
error_msg_from_resp(http_resp).await,
|
||||||
InputParamsError::InvalidEmail(email).to_string()
|
InputParamsError::InvalidEmail(email).to_string()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,189 +0,0 @@
|
||||||
use appflowy_server::application::{init_state, Application};
|
|
||||||
use appflowy_server::config::config::{get_configuration, DatabaseSetting, TlsConfig};
|
|
||||||
use appflowy_server::state::State;
|
|
||||||
use appflowy_server::telemetry::{get_subscriber, init_subscriber};
|
|
||||||
use once_cell::sync::Lazy;
|
|
||||||
use reqwest::Certificate;
|
|
||||||
|
|
||||||
use appflowy_server::component::auth::{RegisterResponse, HEADER_TOKEN};
|
|
||||||
use sqlx::types::Uuid;
|
|
||||||
use sqlx::{Connection, Executor, PgConnection, PgPool};
|
|
||||||
|
|
||||||
// Ensure that the `tracing` stack is only initialised once using `once_cell`
|
|
||||||
static TRACING: Lazy<()> = Lazy::new(|| {
|
|
||||||
let level = "info".to_string();
|
|
||||||
let mut filters = vec![];
|
|
||||||
filters.push(format!("appflowy_server={}", level));
|
|
||||||
filters.push(format!("hyper={}", level));
|
|
||||||
|
|
||||||
let subscriber_name = "test".to_string();
|
|
||||||
let subscriber = get_subscriber(subscriber_name, filters.join(","), std::io::stdout);
|
|
||||||
init_subscriber(subscriber);
|
|
||||||
});
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct TestServer {
|
|
||||||
pub state: State,
|
|
||||||
pub api_client: reqwest::Client,
|
|
||||||
pub address: String,
|
|
||||||
pub port: u16,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TestServer {
|
|
||||||
pub async fn register(&self, name: &str, email: &str, password: &str) -> reqwest::Response {
|
|
||||||
let payload = serde_json::json!({
|
|
||||||
"name": name,
|
|
||||||
"password": password,
|
|
||||||
"email": email
|
|
||||||
});
|
|
||||||
let url = format!("{}/api/user/register", self.address);
|
|
||||||
self.api_client
|
|
||||||
.post(&url)
|
|
||||||
.json(&payload)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.expect("Register failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn login(&self, email: &str, password: &str) -> reqwest::Response {
|
|
||||||
let payload = serde_json::json!({
|
|
||||||
"password": password,
|
|
||||||
"email": email
|
|
||||||
});
|
|
||||||
let url = format!("{}/api/user/login", self.address);
|
|
||||||
self.api_client
|
|
||||||
.post(&url)
|
|
||||||
.json(&payload)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.expect("Login failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn change_password(
|
|
||||||
&self,
|
|
||||||
token: String,
|
|
||||||
current_password: &str,
|
|
||||||
new_password: &str,
|
|
||||||
new_password_confirm: &str,
|
|
||||||
) -> reqwest::Response {
|
|
||||||
let payload = serde_json::json!({
|
|
||||||
"current_password": current_password,
|
|
||||||
"new_password": new_password,
|
|
||||||
"new_password_confirm": new_password_confirm
|
|
||||||
});
|
|
||||||
let url = format!("{}/api/user/password", self.address);
|
|
||||||
self.api_client
|
|
||||||
.post(&url)
|
|
||||||
.header(HEADER_TOKEN, token)
|
|
||||||
.json(&payload)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.expect("Change password failed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn spawn_server() -> TestServer {
|
|
||||||
Lazy::force(&TRACING);
|
|
||||||
let database_name = Uuid::new_v4().to_string();
|
|
||||||
let config = {
|
|
||||||
let mut config = get_configuration().expect("Failed to read configuration.");
|
|
||||||
config.database.database_name = database_name.clone();
|
|
||||||
// Use a random OS port
|
|
||||||
config.application.port = 0;
|
|
||||||
config
|
|
||||||
};
|
|
||||||
|
|
||||||
let _ = configure_database(&config.database).await;
|
|
||||||
let state = init_state(&config).await;
|
|
||||||
let application = Application::build(config.clone(), state.clone())
|
|
||||||
.await
|
|
||||||
.expect("Failed to build application");
|
|
||||||
|
|
||||||
let port = application.port();
|
|
||||||
let _ = tokio::spawn(async {
|
|
||||||
let _ = application.run_until_stopped().await;
|
|
||||||
});
|
|
||||||
let mut builder = reqwest::Client::builder();
|
|
||||||
let mut address = format!("http://localhost:{}", port);
|
|
||||||
if config.application.use_https() {
|
|
||||||
address = format!("https://localhost:{}", port);
|
|
||||||
builder = builder.add_root_certificate(
|
|
||||||
Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let api_client = builder
|
|
||||||
.add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap())
|
|
||||||
.redirect(reqwest::redirect::Policy::none())
|
|
||||||
.danger_accept_invalid_certs(true)
|
|
||||||
.cookie_store(true)
|
|
||||||
.no_proxy()
|
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
TestServer {
|
|
||||||
state,
|
|
||||||
api_client,
|
|
||||||
address,
|
|
||||||
port,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn configure_database(config: &DatabaseSetting) -> PgPool {
|
|
||||||
// Create database
|
|
||||||
let mut connection = PgConnection::connect_with(&config.without_db())
|
|
||||||
.await
|
|
||||||
.expect("Failed to connect to Postgres");
|
|
||||||
connection
|
|
||||||
.execute(&*format!(r#"CREATE DATABASE "{}";"#, config.database_name))
|
|
||||||
.await
|
|
||||||
.expect("Failed to create database.");
|
|
||||||
|
|
||||||
// Migrate database
|
|
||||||
let connection_pool = PgPool::connect_with(config.with_db())
|
|
||||||
.await
|
|
||||||
.expect("Failed to connect to Postgres.");
|
|
||||||
|
|
||||||
sqlx::migrate!("./migrations")
|
|
||||||
.run(&connection_pool)
|
|
||||||
.await
|
|
||||||
.expect("Failed to migrate the database");
|
|
||||||
|
|
||||||
connection_pool
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(serde::Serialize)]
|
|
||||||
pub struct TestUser {
|
|
||||||
name: String,
|
|
||||||
pub email: String,
|
|
||||||
pub password: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TestUser {
|
|
||||||
pub fn generate() -> Self {
|
|
||||||
Self {
|
|
||||||
name: "Me".to_string(),
|
|
||||||
email: "me@appflowy.io".to_string(),
|
|
||||||
password: "Hello@AppFlowy123".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn register(&self, test_server: &TestServer) -> String {
|
|
||||||
let url = format!("{}/api/user/register", test_server.address);
|
|
||||||
let resp = test_server
|
|
||||||
.api_client
|
|
||||||
.post(&url)
|
|
||||||
.json(self)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.expect("Fail to register user");
|
|
||||||
|
|
||||||
let bytes = resp.bytes().await.unwrap();
|
|
||||||
let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap();
|
|
||||||
response.token
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn error_msg_from_resp(resp: reqwest::Response) -> String {
|
|
||||||
let bytes = resp.bytes().await.unwrap();
|
|
||||||
String::from_utf8(bytes.to_vec()).unwrap()
|
|
||||||
}
|
|
13
tests/api/ws.rs
Normal file
13
tests/api/ws.rs
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
use crate::util::{spawn_server, TestUser};
|
||||||
|
use collab_client_ws::WSClient;
|
||||||
|
|
||||||
|
#[actix_rt::test]
|
||||||
|
async fn ws_conn_test() {
|
||||||
|
let server = spawn_server().await;
|
||||||
|
let test_user = TestUser::generate();
|
||||||
|
let token = test_user.register(&server).await;
|
||||||
|
|
||||||
|
let address = format!("{}/{}", server.ws_addr, token);
|
||||||
|
let client = WSClient::new(address, 100);
|
||||||
|
let _ = client.connect().await;
|
||||||
|
}
|
3
tests/main.rs
Normal file
3
tests/main.rs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
mod api;
|
||||||
|
mod util;
|
||||||
|
mod ws;
|
2
tests/util/mod.rs
Normal file
2
tests/util/mod.rs
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
mod test_server;
|
||||||
|
pub use test_server::*;
|
232
tests/util/test_server.rs
Normal file
232
tests/util/test_server.rs
Normal file
|
@ -0,0 +1,232 @@
|
||||||
|
use appflowy_server::application::{init_state, Application};
|
||||||
|
use appflowy_server::config::config::{get_configuration, DatabaseSetting};
|
||||||
|
use appflowy_server::state::State;
|
||||||
|
use appflowy_server::telemetry::{get_subscriber, init_subscriber};
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use reqwest::Certificate;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use appflowy_server::component::auth::{RegisterResponse, HEADER_TOKEN};
|
||||||
|
use sqlx::types::Uuid;
|
||||||
|
use sqlx::{Connection, Executor, PgConnection, PgPool};
|
||||||
|
|
||||||
|
// Ensure that the `tracing` stack is only initialised once using `once_cell`
|
||||||
|
static TRACING: Lazy<()> = Lazy::new(|| {
|
||||||
|
let level = "trace".to_string();
|
||||||
|
let mut filters = vec![];
|
||||||
|
filters.push(format!("appflowy_server={}", level));
|
||||||
|
filters.push(format!("collab_client_ws={}", level));
|
||||||
|
filters.push(format!("hyper={}", level));
|
||||||
|
filters.push(format!("actix_web={}", level));
|
||||||
|
|
||||||
|
let subscriber_name = "test".to_string();
|
||||||
|
let subscriber = get_subscriber(subscriber_name, filters.join(","), std::io::stdout);
|
||||||
|
init_subscriber(subscriber);
|
||||||
|
});
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct TestServer {
|
||||||
|
pub state: State,
|
||||||
|
pub api_client: reqwest::Client,
|
||||||
|
pub address: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub ws_addr: String,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub cleaner: Cleaner,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestServer {
|
||||||
|
pub async fn register(&self, name: &str, email: &str, password: &str) -> reqwest::Response {
|
||||||
|
let payload = serde_json::json!({
|
||||||
|
"name": name,
|
||||||
|
"password": password,
|
||||||
|
"email": email
|
||||||
|
});
|
||||||
|
let url = format!("{}/api/user/register", self.address);
|
||||||
|
self
|
||||||
|
.api_client
|
||||||
|
.post(&url)
|
||||||
|
.json(&payload)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("Register failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn login(&self, email: &str, password: &str) -> reqwest::Response {
|
||||||
|
let payload = serde_json::json!({
|
||||||
|
"password": password,
|
||||||
|
"email": email
|
||||||
|
});
|
||||||
|
let url = format!("{}/api/user/login", self.address);
|
||||||
|
self
|
||||||
|
.api_client
|
||||||
|
.post(&url)
|
||||||
|
.json(&payload)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("Login failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn change_password(
|
||||||
|
&self,
|
||||||
|
token: String,
|
||||||
|
current_password: &str,
|
||||||
|
new_password: &str,
|
||||||
|
new_password_confirm: &str,
|
||||||
|
) -> reqwest::Response {
|
||||||
|
let payload = serde_json::json!({
|
||||||
|
"current_password": current_password,
|
||||||
|
"new_password": new_password,
|
||||||
|
"new_password_confirm": new_password_confirm
|
||||||
|
});
|
||||||
|
let url = format!("{}/api/user/password", self.address);
|
||||||
|
self
|
||||||
|
.api_client
|
||||||
|
.post(&url)
|
||||||
|
.header(HEADER_TOKEN, token)
|
||||||
|
.json(&payload)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("Change password failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn spawn_server() -> TestServer {
|
||||||
|
Lazy::force(&TRACING);
|
||||||
|
let database_name = Uuid::new_v4().to_string();
|
||||||
|
let config = {
|
||||||
|
let mut config = get_configuration().expect("Failed to read configuration.");
|
||||||
|
config.database.database_name = database_name.clone();
|
||||||
|
// Use a random OS port
|
||||||
|
config.application.port = 0;
|
||||||
|
config.application.data_dir = PathBuf::from(format!("./data/{}", database_name));
|
||||||
|
config
|
||||||
|
};
|
||||||
|
|
||||||
|
let _ = configure_database(&config.database).await;
|
||||||
|
let state = init_state(&config).await;
|
||||||
|
let application = Application::build(config.clone(), state.clone())
|
||||||
|
.await
|
||||||
|
.expect("Failed to build application");
|
||||||
|
|
||||||
|
let port = application.port();
|
||||||
|
let _ = tokio::spawn(async {
|
||||||
|
let _ = application.run_until_stopped().await;
|
||||||
|
});
|
||||||
|
let mut builder = reqwest::Client::builder();
|
||||||
|
let mut address = format!("http://localhost:{}", port);
|
||||||
|
let mut ws_addr = format!("ws://localhost:{}/ws", port);
|
||||||
|
if config.application.use_https() {
|
||||||
|
address = format!("https://localhost:{}", port);
|
||||||
|
ws_addr = format!("wss://localhost:{}/ws", port);
|
||||||
|
builder = builder
|
||||||
|
.add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
let api_client = builder
|
||||||
|
.add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap())
|
||||||
|
.redirect(reqwest::redirect::Policy::none())
|
||||||
|
.danger_accept_invalid_certs(true)
|
||||||
|
.cookie_store(true)
|
||||||
|
.no_proxy()
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cleaner = Cleaner::new(config.application.data_dir);
|
||||||
|
|
||||||
|
TestServer {
|
||||||
|
state,
|
||||||
|
api_client,
|
||||||
|
address,
|
||||||
|
ws_addr,
|
||||||
|
port,
|
||||||
|
cleaner,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn configure_database(config: &DatabaseSetting) -> PgPool {
|
||||||
|
// Create database
|
||||||
|
let mut connection = PgConnection::connect_with(&config.without_db())
|
||||||
|
.await
|
||||||
|
.expect("Failed to connect to Postgres");
|
||||||
|
connection
|
||||||
|
.execute(&*format!(r#"CREATE DATABASE "{}";"#, config.database_name))
|
||||||
|
.await
|
||||||
|
.expect("Failed to create database.");
|
||||||
|
|
||||||
|
// Migrate database
|
||||||
|
let connection_pool = PgPool::connect_with(config.with_db())
|
||||||
|
.await
|
||||||
|
.expect("Failed to connect to Postgres.");
|
||||||
|
|
||||||
|
sqlx::migrate!("./migrations")
|
||||||
|
.run(&connection_pool)
|
||||||
|
.await
|
||||||
|
.expect("Failed to migrate the database");
|
||||||
|
|
||||||
|
connection_pool
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize)]
|
||||||
|
pub struct TestUser {
|
||||||
|
name: String,
|
||||||
|
pub email: String,
|
||||||
|
pub password: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestUser {
|
||||||
|
pub fn generate() -> Self {
|
||||||
|
Self {
|
||||||
|
name: "Me".to_string(),
|
||||||
|
email: "me@appflowy.io".to_string(),
|
||||||
|
password: "Hello@AppFlowy123".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn register(&self, test_server: &TestServer) -> String {
|
||||||
|
let url = format!("{}/api/user/register", test_server.address);
|
||||||
|
let resp = test_server
|
||||||
|
.api_client
|
||||||
|
.post(&url)
|
||||||
|
.json(self)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("Fail to register user");
|
||||||
|
|
||||||
|
let bytes = resp.bytes().await.unwrap();
|
||||||
|
let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap();
|
||||||
|
response.token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn error_msg_from_resp(resp: reqwest::Response) -> String {
|
||||||
|
let bytes = resp.bytes().await.unwrap();
|
||||||
|
String::from_utf8(bytes.to_vec()).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Cleaner {
|
||||||
|
path: PathBuf,
|
||||||
|
should_clean: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Cleaner {
|
||||||
|
fn new(path: PathBuf) -> Self {
|
||||||
|
Self {
|
||||||
|
path,
|
||||||
|
should_clean: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cleanup(dir: &PathBuf) {
|
||||||
|
let _ = std::fs::remove_dir_all(dir);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for Cleaner {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if self.should_clean {
|
||||||
|
Self::cleanup(&self.path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
1
tests/ws/mod.rs
Normal file
1
tests/ws/mod.rs
Normal file
|
@ -0,0 +1 @@
|
||||||
|
|
Loading…
Add table
Reference in a new issue