From 18e950a8296e2c37a8b65825e116cd83aa41cfc9 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Mon, 8 May 2023 19:03:50 +0800 Subject: [PATCH] feat: ws connect (#3) * chore: ws * chore: build client stream * feat: test ws connect * ci: fix ci --- .gitignore | 4 +- Cargo.lock | 449 ++++++++++++++++++++++++++-- Cargo.toml | 19 +- configuration/base.yaml | 1 + crates/revdb/Cargo.toml | 15 - crates/revdb/src/db.rs | 56 ---- crates/revdb/src/document.rs | 117 -------- crates/revdb/src/error.rs | 11 - crates/revdb/src/lib.rs | 4 - crates/revdb/src/range.rs | 48 --- crates/revdb/tests/document/mod.rs | 2 - crates/revdb/tests/document/test.rs | 136 --------- crates/revdb/tests/document/util.rs | 7 - crates/revdb/tests/main.rs | 1 - crates/snowflake/src/lib.rs | 93 +++--- crates/token/src/lib.rs | 78 ++--- crates/websocket/Cargo.toml | 25 ++ crates/websocket/src/client.rs | 168 +++++++++++ crates/websocket/src/entities.rs | 56 ++++ crates/websocket/src/error.rs | 8 + crates/websocket/src/lib.rs | 7 + crates/websocket/src/server.rs | 190 ++++++++++++ rustfmt.toml | 12 + src/api/user.rs | 120 ++++---- src/api/ws.rs | 44 +-- src/application.rs | 172 ++++++----- src/component/auth/error.rs | 92 +++--- src/component/auth/password.rs | 107 +++---- src/component/auth/user.rs | 392 ++++++++++++------------ src/component/mod.rs | 1 - src/component/token_state.rs | 36 +-- src/component/ws/client.rs | 164 ---------- src/component/ws/entities.rs | 92 ------ src/component/ws/mod.rs | 11 - src/component/ws/server.rs | 44 --- src/config/config.rs | 140 ++++----- src/config/env.rs | 6 +- src/domain/user_email.rs | 58 ++-- src/domain/user_name.rs | 118 ++++---- src/domain/user_password.rs | 60 ++-- src/main.rs | 18 +- src/middleware/cors.rs | 14 +- src/self_signed.rs | 40 +-- src/state.rs | 94 +++--- src/telemetry.rs | 38 +-- tests/api/login.rs | 52 ++-- tests/api/{main.rs => mod.rs} | 2 +- tests/api/password.rs | 76 ++--- tests/api/register.rs | 74 ++--- tests/api/test_server.rs | 189 ------------ tests/api/ws.rs | 13 + tests/main.rs | 3 + tests/util/mod.rs | 2 + tests/util/test_server.rs | 232 ++++++++++++++ tests/ws/mod.rs | 1 + 55 files changed, 2144 insertions(+), 1868 deletions(-) delete mode 100644 crates/revdb/Cargo.toml delete mode 100644 crates/revdb/src/db.rs delete mode 100644 crates/revdb/src/document.rs delete mode 100644 crates/revdb/src/error.rs delete mode 100644 crates/revdb/src/lib.rs delete mode 100644 crates/revdb/src/range.rs delete mode 100644 crates/revdb/tests/document/mod.rs delete mode 100644 crates/revdb/tests/document/test.rs delete mode 100644 crates/revdb/tests/document/util.rs delete mode 100644 crates/revdb/tests/main.rs create mode 100644 crates/websocket/Cargo.toml create mode 100644 crates/websocket/src/client.rs create mode 100644 crates/websocket/src/entities.rs create mode 100644 crates/websocket/src/error.rs create mode 100644 crates/websocket/src/lib.rs create mode 100644 crates/websocket/src/server.rs create mode 100644 rustfmt.toml delete mode 100644 src/component/ws/client.rs delete mode 100644 src/component/ws/entities.rs delete mode 100644 src/component/ws/mod.rs delete mode 100644 src/component/ws/server.rs rename tests/api/{main.rs => mod.rs} (69%) delete mode 100644 tests/api/test_server.rs create mode 100644 tests/api/ws.rs create mode 100644 tests/main.rs create mode 100644 tests/util/mod.rs create mode 100644 tests/util/test_server.rs create mode 100644 tests/ws/mod.rs diff --git a/.gitignore b/.gitignore index e1c34554..549746fd 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,6 @@ **/temp/** package-lock.json yarn.lock -node_modules \ No newline at end of file +node_modules +**/crates/AppFlowy-Collab/ +data/ \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index e30b4615..fb2a4e37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -89,7 +89,7 @@ dependencies = [ "mime", "percent-encoding", "pin-project-lite", - "rand", + "rand 0.8.5", "sha1", "smallvec", "tokio", @@ -189,7 +189,7 @@ dependencies = [ "anyhow", "async-trait", "derive_more", - "rand", + "rand 0.8.5", "redis", "serde", "serde_json", @@ -370,7 +370,7 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" dependencies = [ - "getrandom", + "getrandom 0.2.9", "once_cell", "version_check", ] @@ -382,7 +382,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.9", "once_cell", "version_check", ] @@ -446,6 +446,9 @@ dependencies = [ "bincode", "bytes", "chrono", + "collab-client-ws", + "collab-persistence", + "collab-sync", "config", "dashmap", "derive_more", @@ -454,7 +457,7 @@ dependencies = [ "lazy_static", "once_cell", "openssl", - "rand", + "rand 0.8.5", "rcgen", "reqwest", "secrecy", @@ -474,6 +477,7 @@ dependencies = [ "unicode-segmentation", "uuid", "validator", + "websocket", ] [[package]] @@ -574,6 +578,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atomic_refcell" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79d6dc922a2792b006573f60b2648076355daeae5ce9cb59507e5908c9625d31" + [[package]] name = "autocfg" version = "1.1.0" @@ -613,6 +623,26 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.64.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4243e6031260db77ede97ad86c27e501d646a27ab57b59a574f725d98ab1fb4" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "lazy_static", + "lazycell", + "peeking_take_while", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 1.0.109", +] + [[package]] name = "bit-set" version = "0.5.3" @@ -700,6 +730,17 @@ dependencies = [ "bytes", ] +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "cc" version = "1.0.79" @@ -709,6 +750,15 @@ dependencies = [ "jobserver", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -738,6 +788,17 @@ dependencies = [ "inout", ] +[[package]] +name = "clang-sys" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c688fc74432808e3eb684cae8830a86be1d66a2bd58e1f248ed0960a590baf6f" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "codespan-reporting" version = "0.11.1" @@ -748,6 +809,91 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "collab" +version = "0.1.0" +dependencies = [ + "anyhow", + "bytes", + "lib0", + "parking_lot 0.12.1", + "serde", + "serde_json", + "thiserror", + "tracing", + "y-sync", + "yrs", +] + +[[package]] +name = "collab-client-ws" +version = "0.1.0" +dependencies = [ + "bytes", + "futures-util", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-retry", + "tokio-stream", + "tokio-tungstenite", + "tracing", +] + +[[package]] +name = "collab-persistence" +version = "0.1.0" +dependencies = [ + "bincode", + "chrono", + "lazy_static", + "lib0", + "parking_lot 0.12.1", + "rocksdb", + "serde", + "sled", + "smallvec", + "thiserror", + "tokio", + "tracing", + "yrs", +] + +[[package]] +name = "collab-plugins" +version = "0.1.0" +dependencies = [ + "collab", + "collab-client-ws", + "collab-persistence", + "collab-sync", + "tracing", + "y-sync", + "yrs", +] + +[[package]] +name = "collab-sync" +version = "0.1.0" +dependencies = [ + "bytes", + "collab", + "futures-util", + "lib0", + "md5", + "parking_lot 0.12.1", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", + "y-sync", + "yrs", +] + [[package]] name = "combine" version = "4.6.6" @@ -793,7 +939,7 @@ dependencies = [ "hkdf", "hmac", "percent-encoding", - "rand", + "rand 0.8.5", "sha2", "subtle", "time", @@ -914,7 +1060,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", - "rand_core", + "rand_core 0.6.4", "typenum", ] @@ -1308,6 +1454,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", + "wasm-bindgen", +] + [[package]] name = "getrandom" version = "0.2.9" @@ -1316,7 +1475,7 @@ checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", ] [[package]] @@ -1329,6 +1488,12 @@ dependencies = [ "polyval", ] +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "h2" version = "0.3.18" @@ -1635,12 +1800,65 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + +[[package]] +name = "lib0" +version = "0.16.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf23122cb1c970b77ea6030eac5e328669415b65d2ab245c99bfb110f9d62dc" +dependencies = [ + "serde", + "serde_json", + "thiserror", +] + [[package]] name = "libc" version = "0.2.142" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317" +[[package]] +name = "libloading" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" +dependencies = [ + "cfg-if", + "winapi", +] + +[[package]] +name = "librocksdb-sys" +version = "0.10.0+7.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fe4d5874f5ff2bc616e55e8c6086d478fcda13faf9495768a4aa1c22042d30b" +dependencies = [ + "bindgen", + "bzip2-sys", + "cc", + "glob", + "libc", + "libz-sys", + "zstd-sys", +] + +[[package]] +name = "libz-sys" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56ee889ecc9568871456d42f603d6a0ce59ff328d291063a45cbdf0036baf6db" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "link-cplusplus" version = "1.0.8" @@ -1723,6 +1941,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.5.0" @@ -1767,7 +1991,7 @@ checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" dependencies = [ "libc", "log", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.45.0", ] @@ -1975,7 +2199,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" dependencies = [ "base64ct", - "rand_core", + "rand_core 0.6.4", "subtle", ] @@ -1991,6 +2215,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd" +[[package]] +name = "peeking_take_while" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" + [[package]] name = "pem" version = "1.1.1" @@ -2096,6 +2326,19 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +dependencies = [ + "getrandom 0.1.16", + "libc", + "rand_chacha 0.2.2", + "rand_core 0.5.1", + "rand_hc", +] + [[package]] name = "rand" version = "0.8.5" @@ -2103,8 +2346,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +dependencies = [ + "ppv-lite86", + "rand_core 0.5.1", ] [[package]] @@ -2114,7 +2367,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom 0.1.16", ] [[package]] @@ -2123,7 +2385,16 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.9", +] + +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +dependencies = [ + "rand_core 0.5.1", ] [[package]] @@ -2186,7 +2457,7 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ - "getrandom", + "getrandom 0.2.9", "redox_syscall 0.2.16", "thiserror", ] @@ -2264,17 +2535,6 @@ dependencies = [ "winreg", ] -[[package]] -name = "revdb" -version = "0.1.0" -dependencies = [ - "bincode", - "serde", - "sled", - "tempfile", - "thiserror", -] - [[package]] name = "ring" version = "0.16.20" @@ -2290,6 +2550,22 @@ dependencies = [ "winapi", ] +[[package]] +name = "rocksdb" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "015439787fce1e75d55f279078d33ff14b4af5d93d995e8838ee4631301c8a99" +dependencies = [ + "libc", + "librocksdb-sys", +] + +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc_version" version = "0.4.0" @@ -2504,6 +2780,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -2538,6 +2820,15 @@ dependencies = [ "parking_lot 0.11.2", ] +[[package]] +name = "smallstr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e922794d168678729ffc7e07182721a14219c65814e66e91b839a272fe5ae4f" +dependencies = [ + "smallvec", +] + [[package]] name = "smallvec" version = "1.10.0" @@ -2621,7 +2912,7 @@ dependencies = [ "once_cell", "paste", "percent-encoding", - "rand", + "rand 0.8.5", "rustls", "rustls-pemfile", "serde", @@ -2881,6 +3172,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-retry" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f" +dependencies = [ + "pin-project", + "rand 0.8.5", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.23.4" @@ -2901,6 +3203,19 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", ] [[package]] @@ -3035,6 +3350,25 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "tungstenite" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788" +dependencies = [ + "base64 0.13.1", + "byteorder", + "bytes", + "http", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.16.0" @@ -3113,13 +3447,19 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "uuid" version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dad5567ad0cf5b760e5665964bec1b47dfd077ba8a2544b513f3556d3d239a2" dependencies = [ - "getrandom", + "getrandom 0.2.9", "serde", ] @@ -3166,6 +3506,12 @@ dependencies = [ "try-lock", ] +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -3267,6 +3613,28 @@ dependencies = [ "webpki", ] +[[package]] +name = "websocket" +version = "0.1.0" +dependencies = [ + "actix", + "actix-web-actors", + "bytes", + "collab", + "collab-persistence", + "collab-plugins", + "collab-sync", + "dashmap", + "futures-util", + "parking_lot 0.12.1", + "secrecy", + "serde", + "thiserror", + "tokio", + "tokio-stream", + "tracing", +] + [[package]] name = "whoami" version = "1.4.0" @@ -3492,6 +3860,17 @@ dependencies = [ "time", ] +[[package]] +name = "y-sync" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f54d34b68ec4514a0659838c2b1ba867c571b20b3804a1338dacf4fa9062d801" +dependencies = [ + "lib0", + "thiserror", + "yrs", +] + [[package]] name = "yaml-rust" version = "0.4.5" @@ -3510,6 +3889,20 @@ dependencies = [ "time", ] +[[package]] +name = "yrs" +version = "0.16.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c2aef2bf89b4f7c003f9c73f1c8097427ca32e1d006443f3f607f11e79a797b" +dependencies = [ + "atomic_refcell", + "lib0", + "rand 0.7.3", + "smallstr", + "smallvec", + "thiserror", +] + [[package]] name = "zeroize" version = "1.6.0" diff --git a/Cargo.toml b/Cargo.toml index c9fd8628..b58d1cf6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,8 @@ bytes = "1.4.0" bincode = "1.3.3" dashmap = "5.4" rcgen = { version = "0.10.0", features = ["pem", "x509-parser"] } +collab-sync = {version = "0.1.0"} +collab-persistence = {version = "0.1.0"} # tracing tracing = { version = "0.1.37" } @@ -63,9 +65,11 @@ sqlx = { version = "0.6", default-features = false, features = ["runtime-actix-r #Local crate token = { path = "./crates/token" } snowflake = { path = "./crates/snowflake" } +websocket = { path = "./crates/websocket" } [dev-dependencies] once_cell = "1.7.2" +collab-client-ws = { version = "0.1.0" } [[bin]] name = "appflowy_server" @@ -78,6 +82,19 @@ path = "src/lib.rs" [workspace] members = [ "crates/token", - "crates/revdb", "crates/snowflake", + "crates/websocket", ] + +[patch.crates-io] +collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" } +collab-client-ws = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" } +collab-sync= { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" } +collab-persistence = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" } +collab-plugins = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" } + +#collab = { path = "./crates/AppFlowy-Collab/collab" } +#collab-client-ws = { path = "./crates/AppFlowy-Collab/collab-client-ws" } +#collab-sync = { path = "./crates/AppFlowy-Collab/collab-sync" } +#collab-persistence = { path = "./crates/AppFlowy-Collab/collab-persistence" } +#collab-plugins = { path = "./crates/AppFlowy-Collab/collab-plugins"} diff --git a/configuration/base.yaml b/configuration/base.yaml index de76dff1..422d327b 100644 --- a/configuration/base.yaml +++ b/configuration/base.yaml @@ -3,6 +3,7 @@ application: host: 0.0.0.0 server_key: "Should-Use-The-Custom-Server-Key" tls_config: "no_tls" + data_dir: "./data" database: host: "localhost" port: 5432 diff --git a/crates/revdb/Cargo.toml b/crates/revdb/Cargo.toml deleted file mode 100644 index 9d876298..00000000 --- a/crates/revdb/Cargo.toml +++ /dev/null @@ -1,15 +0,0 @@ -[package] -name = "revdb" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -sled = "0.34.7" -thiserror = "1.0.30" -serde = { version = "1.0", features = ["derive"] } -bincode = "1.3.3" - -[dev-dependencies] -tempfile = "3.4.0" \ No newline at end of file diff --git a/crates/revdb/src/db.rs b/crates/revdb/src/db.rs deleted file mode 100644 index 084951b3..00000000 --- a/crates/revdb/src/db.rs +++ /dev/null @@ -1,56 +0,0 @@ -use crate::document::Document; -use crate::error::RevDBError; -use sled::{Batch, Db, IVec}; -use std::path::Path; - -pub struct RevDB { - pub(crate) db: Db, -} - -impl RevDB { - pub fn open(path: impl AsRef) -> Result { - let db = sled::open(path)?; - Ok(Self { db }) - } - - pub fn document(&self) -> Document { - Document { db: self } - } - - pub fn get>(&self, key: K) -> Result, RevDBError> { - let value = self.db.get(key)?; - Ok(value) - } - - pub fn batch_get>( - &self, - from_key: K, - to_key: K, - ) -> Result, RevDBError> { - let iter = self.db.range(from_key..=to_key); - let mut items = vec![]; - for item in iter { - let (_, value) = item?; - items.push(value) - } - Ok(items) - } - - pub fn insert>(&self, key: K, value: &[u8]) -> Result<(), RevDBError> { - let _ = self.db.insert(key, value)?; - Ok(()) - } - - pub fn batch_insert<'a, K: AsRef<[u8]>>( - &self, - items: impl IntoIterator, - ) -> Result<(), RevDBError> { - let mut batch = Batch::default(); - let items = items.into_iter(); - items.for_each(|(key, value)| { - batch.insert(key.as_ref(), value); - }); - self.db.apply_batch(batch)?; - Ok(()) - } -} diff --git a/crates/revdb/src/document.rs b/crates/revdb/src/document.rs deleted file mode 100644 index 0d002bf1..00000000 --- a/crates/revdb/src/document.rs +++ /dev/null @@ -1,117 +0,0 @@ -use crate::db::RevDB; -use crate::error::RevDBError; -use crate::range::RevRange; -use serde::{Deserialize, Serialize}; - -pub struct Document<'a> { - pub(crate) db: &'a RevDB, -} - -impl<'a> Document<'a> { - pub fn insert( - &self, - uid: i64, - document_id: i64, - value: DocumentRevData, - ) -> Result<(), RevDBError> { - let key = make_document_key(uid, document_id, value.rev_id); - self.db.insert(key, &value.to_vec()?)?; - Ok(()) - } - - pub fn get( - &self, - uid: i64, - document_id: i64, - rev_id: i64, - ) -> Result, RevDBError> { - let key = make_document_key(uid, document_id, rev_id); - match self.db.get(key)? { - None => Ok(None), - Some(value) => { - let data = DocumentRevData::from_vec(value.as_ref())?; - Ok(Some(data)) - } - } - } - - pub fn get_with_range>( - &self, - uid: i64, - document_id: i64, - range: R, - ) -> Result, RevDBError> { - let range = range.into(); - let from = make_document_key(uid, document_id, range.start); - let to = make_document_key(uid, document_id, range.end); - self.batch_get(from, to) - } - - pub fn get_after( - &self, - uid: i64, - document_id: i64, - rev_id: i64, - ) -> Result, RevDBError> { - let from = make_document_key(uid, document_id, rev_id); - let to = make_document_key(uid, document_id, i64::MAX); - self.batch_get(from, to) - } - - pub fn get_before( - &self, - uid: i64, - document_id: i64, - rev_id: i64, - ) -> Result, RevDBError> { - let from = make_document_key(uid, document_id, 0); - let to = make_document_key(uid, document_id, rev_id); - self.batch_get(from, to) - } - - fn batch_get>( - &self, - from: K, - to: K, - ) -> Result, RevDBError> { - let items = self.db.batch_get(from, to)?; - let mut document_revs = vec![]; - for item in items { - let rev_data = DocumentRevData::from_vec(item.as_ref())?; - document_revs.push(rev_data); - } - Ok(document_revs) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DocumentRevData { - #[serde(rename = "rid")] - pub rev_id: i64, - - #[serde(rename = "bid")] - pub base_rev_id: i64, - - #[serde(rename = "data")] - pub content: String, -} - -impl DocumentRevData { - pub fn from_vec(data: &[u8]) -> Result { - bincode::deserialize::(data).map_err(|_e| RevDBError::SerdeError) - } - - pub fn to_vec(&self) -> Result, RevDBError> { - bincode::serialize(self).map_err(|_e| RevDBError::SerdeError) - } -} - -// Optimize your data layout: Sled's B-Tree implementation works best when the keys are sequential, -// so try to organize the data in a way that maximizes sequential access. -fn make_document_key(uid: i64, document_id: i64, rev_id: i64) -> [u8; 24] { - let mut key = [0; 24]; - key[0..8].copy_from_slice(&uid.to_be_bytes()); - key[8..16].copy_from_slice(&document_id.to_be_bytes()); - key[16..24].copy_from_slice(&rev_id.to_be_bytes()); - key -} diff --git a/crates/revdb/src/error.rs b/crates/revdb/src/error.rs deleted file mode 100644 index f3a2b9e6..00000000 --- a/crates/revdb/src/error.rs +++ /dev/null @@ -1,11 +0,0 @@ -#[derive(Debug, thiserror::Error)] -pub enum RevDBError { - #[error(transparent)] - Db(#[from] sled::Error), - - #[error("Serde error")] - SerdeError, - - #[error("invalid data")] - InvalidData, -} diff --git a/crates/revdb/src/lib.rs b/crates/revdb/src/lib.rs deleted file mode 100644 index d6cf3006..00000000 --- a/crates/revdb/src/lib.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod db; -pub mod document; -pub mod error; -pub mod range; diff --git a/crates/revdb/src/range.rs b/crates/revdb/src/range.rs deleted file mode 100644 index 0e46e3ef..00000000 --- a/crates/revdb/src/range.rs +++ /dev/null @@ -1,48 +0,0 @@ -use std::ops::{Range, RangeInclusive, RangeToInclusive}; - -#[derive(Clone)] -pub struct RevRange { - pub(crate) start: i64, - pub(crate) end: i64, -} - -impl RevRange { - /// Construct a new `RevRange` representing the range [start..end). - /// It is an invariant that `start <= end`. - pub fn new(start: i64, end: i64) -> RevRange { - debug_assert!(start <= end); - RevRange { start, end } - } -} - -impl From> for RevRange { - fn from(src: RangeInclusive) -> RevRange { - RevRange::new(*src.start(), src.end().saturating_add(1)) - } -} - -impl From> for RevRange { - fn from(src: RangeToInclusive) -> RevRange { - RevRange::new(0, src.end.saturating_add(1)) - } -} - -impl From> for RevRange { - fn from(src: Range) -> RevRange { - let Range { start, end } = src; - RevRange { start, end } - } -} - -impl Iterator for RevRange { - type Item = i64; - - fn next(&mut self) -> Option { - if self.start > self.end { - return None; - } - let val = self.start; - self.start += 1; - Some(val) - } -} diff --git a/crates/revdb/tests/document/mod.rs b/crates/revdb/tests/document/mod.rs deleted file mode 100644 index 89c3cfa9..00000000 --- a/crates/revdb/tests/document/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod test; -mod util; diff --git a/crates/revdb/tests/document/test.rs b/crates/revdb/tests/document/test.rs deleted file mode 100644 index 8dbf51d1..00000000 --- a/crates/revdb/tests/document/test.rs +++ /dev/null @@ -1,136 +0,0 @@ -use crate::document::util::make_test_db; -use revdb::document::{Document, DocumentRevData}; -use revdb::range::RevRange; - -#[test] -fn insert_text() { - let db = make_test_db(); - let document = db.document(); - let uid = 12345678; - let document_id = 1; - let value = DocumentRevData { - rev_id: 0, - base_rev_id: 0, - content: "hello world".to_string(), - }; - document.insert(uid, document_id, value.clone()).unwrap(); - - let restored_data = document.get(uid, document_id, 0).unwrap().unwrap(); - assert_eq!(value.content, restored_data.content); -} - -//noinspection RsExternalLinter -#[test] -fn insert_multi_text() { - let db = make_test_db(); - let document = db.document(); - let uid = 12345678; - let document_id = 1; - - let mut base_rev_id = 0; - let mut expected_str = "".to_string(); - for i in 0..=100 { - let content = i.to_string(); - expected_str.push_str(&content); - let value = DocumentRevData { - rev_id: i, - base_rev_id, - content, - }; - base_rev_id += 1; - document.insert(uid, document_id, value).unwrap(); - } - // - let restored_str = document - .get_with_range(uid, document_id, RevRange::new(0, 100)) - .unwrap() - .into_iter() - .map(|data| data.content) - .collect::>() - .join(""); - - assert_eq!(expected_str, restored_str); -} - -//noinspection RsExternalLinter -fn insert_100_string_to_document(uid: i64, document_id: i64, document: &Document) { - let mut base_rev_id = 0; - for i in 0..=100 { - let content = i.to_string(); - let value = DocumentRevData { - rev_id: i, - base_rev_id, - content, - }; - base_rev_id += 1; - document.insert(uid, document_id, value).unwrap(); - } -} - -fn values_to_string(values: Vec) -> String { - values - .into_iter() - .map(|data| data.content) - .collect::>() - .join("") -} - -#[test] -fn get_value_before() { - let db = make_test_db(); - let document = db.document(); - let uid = 12345678; - let document_id = 1; - insert_100_string_to_document(uid, document_id, &document); - - let restored_str = values_to_string(document.get_before(uid, document_id, 50).unwrap()); - assert_eq!("01234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950", restored_str); - - let restored_str = values_to_string(document.get_before(uid, document_id, 0).unwrap()); - assert_eq!("0", restored_str); -} - -#[test] -fn get_value_after() { - let db = make_test_db(); - let document = db.document(); - let uid = 12345678; - let document_id = 1; - insert_100_string_to_document(uid, document_id, &document); - - let restored_str = values_to_string(document.get_after(uid, document_id, 50).unwrap()); - assert_eq!("5051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100", restored_str); - - let restored_str = values_to_string(document.get_after(uid, document_id, 100).unwrap()); - assert_eq!("100", restored_str); -} - -#[test] -fn get_value_with_range() { - let db = make_test_db(); - let document = db.document(); - let uid = 12345678; - let document_id = 1; - insert_100_string_to_document(uid, document_id, &document); - - let restored_str = values_to_string( - document - .get_with_range(uid, document_id, RevRange::new(50, 60)) - .unwrap(), - ); - assert_eq!("5051525354555657585960", restored_str); - - let restored_str = values_to_string( - document - .get_with_range(uid, document_id, RevRange::new(50, 200)) - .unwrap(), - ); - assert_eq!("5051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100", restored_str); - - let restored_str = values_to_string( - document - .get_with_range(uid, document_id, RevRange::new(50, 50)) - .unwrap(), - ); - assert_eq!("50", restored_str); -} diff --git a/crates/revdb/tests/document/util.rs b/crates/revdb/tests/document/util.rs deleted file mode 100644 index 75b45246..00000000 --- a/crates/revdb/tests/document/util.rs +++ /dev/null @@ -1,7 +0,0 @@ -use revdb::db::RevDB; -use tempfile::TempDir; - -pub fn make_test_db() -> RevDB { - let tempdir = TempDir::new().unwrap(); - RevDB::open(tempdir).unwrap() -} diff --git a/crates/revdb/tests/main.rs b/crates/revdb/tests/main.rs deleted file mode 100644 index 10331894..00000000 --- a/crates/revdb/tests/main.rs +++ /dev/null @@ -1 +0,0 @@ -mod document; diff --git a/crates/snowflake/src/lib.rs b/crates/snowflake/src/lib.rs index b35b209c..b8f10916 100644 --- a/crates/snowflake/src/lib.rs +++ b/crates/snowflake/src/lib.rs @@ -8,66 +8,65 @@ const TIMESTAMP_SHIFT: u64 = NODE_ID_BITS + SEQUENCE_BITS; const SEQUENCE_MASK: u64 = (1 << SEQUENCE_BITS) - 1; pub struct Snowflake { - node_id: u64, - sequence: u64, - last_timestamp: u64, + node_id: u64, + sequence: u64, + last_timestamp: u64, } impl Snowflake { - pub fn new(node_id: u64) -> Snowflake { - Snowflake { - node_id, - sequence: 0, - last_timestamp: 0, - } + pub fn new(node_id: u64) -> Snowflake { + Snowflake { + node_id, + sequence: 0, + last_timestamp: 0, + } + } + + pub fn next_id(&mut self) -> i64 { + let timestamp = self.timestamp(); + if timestamp < self.last_timestamp { + panic!("Clock moved backwards!"); } - pub fn next_id(&mut self) -> i64 { - let timestamp = self.timestamp(); - if timestamp < self.last_timestamp { - panic!("Clock moved backwards!"); - } - - if timestamp == self.last_timestamp { - self.sequence = (self.sequence + 1) & SEQUENCE_MASK; - if self.sequence == 0 { - self.wait_next_millis(); - } - } else { - self.sequence = 0; - } - - self.last_timestamp = timestamp; - let id = - (timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence; - id as i64 + if timestamp == self.last_timestamp { + self.sequence = (self.sequence + 1) & SEQUENCE_MASK; + if self.sequence == 0 { + self.wait_next_millis(); + } + } else { + self.sequence = 0; } - fn wait_next_millis(&self) { - let mut timestamp = self.timestamp(); - while timestamp == self.last_timestamp { - timestamp = self.timestamp(); - } - } + self.last_timestamp = timestamp; + let id = (timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence; + id as i64 + } - fn timestamp(&self) -> u64 { - SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .expect("Clock moved backwards!") - .as_millis() as u64 + fn wait_next_millis(&self) { + let mut timestamp = self.timestamp(); + while timestamp == self.last_timestamp { + timestamp = self.timestamp(); } + } + + fn timestamp(&self) -> u64 { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("Clock moved backwards!") + .as_millis() as u64 + } } #[cfg(test)] mod tests { - use crate::Snowflake; + use crate::Snowflake; - #[test] - fn gen_id() { - let mut snow_flake = Snowflake::new(1); - let id_1 = snow_flake.next_id(); - let id_2 = snow_flake.next_id(); + #[test] + fn gen_id() { + let mut snow_flake = Snowflake::new(1); + let id_1 = snow_flake.next_id(); + let id_2 = snow_flake.next_id(); - assert_ne!(id_1, id_2); - } + assert_ne!(id_1, id_2); + } } diff --git a/crates/token/src/lib.rs b/crates/token/src/lib.rs index 895f7314..0096741f 100644 --- a/crates/token/src/lib.rs +++ b/crates/token/src/lib.rs @@ -7,70 +7,72 @@ use sha2::Sha256; #[derive(Debug, thiserror::Error)] pub enum TokenError { - #[error(transparent)] - Jwt(#[from] jwt::Error), + #[error(transparent)] + Jwt(#[from] jwt::Error), - #[error("Token expired")] - Expired, + #[error("Token expired")] + Expired, } #[derive(Debug, Serialize, Deserialize, Eq, PartialEq)] pub enum TokenType { - AccessToken, + AccessToken, } #[derive(Debug, Serialize, Deserialize)] struct TokenFields { - #[serde(rename = "d")] - data: T, + #[serde(rename = "d")] + data: T, - #[serde(rename = "exp")] - expire_at: DateTime, + #[serde(rename = "exp")] + expire_at: DateTime, } pub fn create_token( - server_key: &str, - data: impl Serialize, - expire_duration: Duration, + server_key: &str, + data: impl Serialize, + expire_duration: Duration, ) -> Result { - Ok(TokenFields { - data, - expire_at: Utc::now() + expire_duration, + Ok( + TokenFields { + data, + expire_at: Utc::now() + expire_duration, } - .sign_with_key(&generate_hmac_key(server_key))?) + .sign_with_key(&generate_hmac_key(server_key))?, + ) } fn generate_hmac_key(server_key: &str) -> Hmac { - Hmac::::new_from_slice(server_key.as_bytes()).expect("invalid server key") + Hmac::::new_from_slice(server_key.as_bytes()).expect("invalid server key") } pub fn parse_token(server_key: &str, token: &str) -> Result { - let fields = - VerifyWithKey::>::verify_with_key(token, &generate_hmac_key(server_key))?; - if fields.expire_at < Utc::now() { - return Err(TokenError::Expired); - } - Ok(fields.data) + let fields = + VerifyWithKey::>::verify_with_key(token, &generate_hmac_key(server_key))?; + if fields.expire_at < Utc::now() { + return Err(TokenError::Expired); + } + Ok(fields.data) } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn create_token_test() { - let token_data = "hello appflowy".to_string(); - let token = create_token("server_key", &token_data, Duration::days(2)).unwrap(); + #[test] + fn create_token_test() { + let token_data = "hello appflowy".to_string(); + let token = create_token("server_key", &token_data, Duration::days(2)).unwrap(); - let parse_token_data = parse_token::("server_key", &token).unwrap(); - assert_eq!(token_data, parse_token_data); - } + let parse_token_data = parse_token::("server_key", &token).unwrap(); + assert_eq!(token_data, parse_token_data); + } - #[test] - #[should_panic] - fn parser_token_with_different_server_key() { - let server_key = "123456"; - let token = create_token(server_key, "hello", Duration::days(2)).unwrap(); - let _ = parse_token::("abcdef", &token).unwrap(); - } + #[test] + #[should_panic] + fn parser_token_with_different_server_key() { + let server_key = "123456"; + let token = create_token(server_key, "hello", Duration::days(2)).unwrap(); + let _ = parse_token::("abcdef", &token).unwrap(); + } } diff --git a/crates/websocket/Cargo.toml b/crates/websocket/Cargo.toml new file mode 100644 index 00000000..64d85c68 --- /dev/null +++ b/crates/websocket/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "websocket" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +actix = "0.13" +actix-web-actors = { version = "4.2.0" } +serde = { version = "1.0", features = ["derive"] } +thiserror = "1.0.30" +bytes = "1.0" +secrecy = { version = "0.8", features = ["serde"] } +parking_lot = "0.12.1" +tracing = "0.1.25" +futures-util = "0.3.26" +tokio-stream = { version = "0.1.14", features = ["sync"] } +tokio = { version = "1.26", features = ["sync"] } +dashmap = "5.4.0" + +collab = { version = "0.1.0"} +collab-sync = { version = "0.1.0"} +collab-persistence = { version = "0.1.0"} +collab-plugins = { version = "0.1.0", features = ["disk_rocksdb"]} diff --git a/crates/websocket/src/client.rs b/crates/websocket/src/client.rs new file mode 100644 index 00000000..551d0a61 --- /dev/null +++ b/crates/websocket/src/client.rs @@ -0,0 +1,168 @@ +use crate::entities::{ClientMessage, Connect, Disconnect, ServerMessage, WSUser}; +use crate::error::WSError; +use crate::CollabServer; +use actix::{ + fut, Actor, ActorContext, ActorFutureExt, Addr, AsyncContext, ContextFutureSpawner, Handler, + Recipient, Running, StreamHandler, WrapFuture, +}; +use actix_web_actors::ws; +use bytes::Bytes; + +use collab_sync::msg::CollabMessage; +use futures_util::Sink; + +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::{Duration, Instant}; + +const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); +const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); + +pub struct CollabSession { + user: Arc, + hb: Instant, + pub server: Addr, +} + +impl CollabSession { + pub fn new(user: WSUser, server: Addr) -> Self { + Self { + user: Arc::new(user), + hb: Instant::now(), + server, + } + } + + fn hb(&self, ctx: &mut ws::WebsocketContext) { + ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { + if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT { + act.server.do_send(Disconnect { + user: act.user.clone(), + }); + ctx.stop(); + return; + } + + ctx.ping(b""); + }); + } + + fn send_to_server(&self, bytes: Bytes) { + match CollabMessage::from_vec(bytes.to_vec()) { + Ok(collab_msg) => { + self.server.do_send(ClientMessage { + user: self.user.clone(), + collab_msg, + }); + }, + Err(e) => { + tracing::error!("Error parsing message: {:?}", e); + }, + } + } +} + +impl Actor for CollabSession { + type Context = ws::WebsocketContext; + + fn started(&mut self, ctx: &mut Self::Context) { + // start heartbeats otherwise server disconnects in 10 seconds + self.hb(ctx); + + self + .server + .send(Connect { + socket: ctx.address().recipient(), + user: self.user.clone(), + }) + .into_actor(self) + .then(|res, _session, ctx| { + match res { + Ok(Ok(_)) => { + tracing::trace!("Send connect message to server success") + }, + _ => { + tracing::error!("Send connect message to server failed"); + ctx.stop(); + }, + } + fut::ready(()) + }) + .wait(ctx); + } + + fn stopping(&mut self, _: &mut Self::Context) -> Running { + self.server.do_send(Disconnect { + user: self.user.clone(), + }); + Running::Stop + } +} + +impl Handler for CollabSession { + type Result = (); + + fn handle(&mut self, msg: ServerMessage, ctx: &mut Self::Context) { + ctx.binary(msg.collab_msg); + } +} + +/// WebSocket message handler +impl StreamHandler> for CollabSession { + fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { + let msg = match msg { + Err(_) => { + ctx.stop(); + return; + }, + Ok(msg) => msg, + }; + + match msg { + ws::Message::Ping(msg) => { + self.hb = Instant::now(); + ctx.pong(&msg); + }, + ws::Message::Pong(_) => { + self.hb = Instant::now(); + }, + ws::Message::Text(_) => {}, + ws::Message::Binary(bytes) => { + self.send_to_server(bytes); + }, + ws::Message::Close(reason) => { + ctx.close(reason); + ctx.stop(); + }, + ws::Message::Continuation(_) => { + ctx.stop(); + }, + ws::Message::Nop => (), + } + } +} + +/// A helper struct that wraps the [Recipient] type to implement the [Sink] trait +pub struct ClientSink(pub Recipient); + +impl Sink for ClientSink { + type Error = WSError; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: CollabMessage) -> Result<(), Self::Error> { + self.0.do_send(ServerMessage { collab_msg: item }); + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} diff --git a/crates/websocket/src/entities.rs b/crates/websocket/src/entities.rs new file mode 100644 index 00000000..3ba6cf54 --- /dev/null +++ b/crates/websocket/src/entities.rs @@ -0,0 +1,56 @@ +use crate::error::WSError; +use actix::{Message, Recipient}; + +use collab_sync::msg::CollabMessage; +use secrecy::{ExposeSecret, Secret}; + +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct WSUser { + pub user_id: Secret, +} + +impl Hash for WSUser { + fn hash(&self, state: &mut H) { + let uid: &String = self.user_id.expose_secret(); + uid.hash(state); + } +} + +impl PartialEq for WSUser { + fn eq(&self, other: &Self) -> bool { + let uid: &String = self.user_id.expose_secret(); + let other_uid: &String = other.user_id.expose_secret(); + uid == other_uid + } +} + +impl Eq for WSUser {} + +#[derive(Debug, Message, Clone)] +#[rtype(result = "Result<(), WSError>")] +pub struct Connect { + pub socket: Recipient, + pub user: Arc, +} + +#[derive(Debug, Message, Clone)] +#[rtype(result = "Result<(), WSError>")] +pub struct Disconnect { + pub user: Arc, +} + +#[derive(Debug, Message, Clone)] +#[rtype(result = "()")] +pub struct ClientMessage { + pub user: Arc, + pub collab_msg: CollabMessage, +} + +#[derive(Debug, Message, Clone)] +#[rtype(result = "()")] +pub struct ServerMessage { + pub collab_msg: CollabMessage, +} diff --git a/crates/websocket/src/error.rs b/crates/websocket/src/error.rs new file mode 100644 index 00000000..54729bf7 --- /dev/null +++ b/crates/websocket/src/error.rs @@ -0,0 +1,8 @@ +#[derive(Debug, thiserror::Error)] +pub enum WSError { + #[error(transparent)] + Persistence(#[from] collab_persistence::error::PersistenceError), + + #[error("Internal failure: {0}")] + Internal(#[from] Box), +} diff --git a/crates/websocket/src/lib.rs b/crates/websocket/src/lib.rs new file mode 100644 index 00000000..e98c29ca --- /dev/null +++ b/crates/websocket/src/lib.rs @@ -0,0 +1,7 @@ +mod client; +pub mod entities; +mod error; +mod server; + +pub use client::*; +pub use server::*; diff --git a/crates/websocket/src/server.rs b/crates/websocket/src/server.rs new file mode 100644 index 00000000..8eb573af --- /dev/null +++ b/crates/websocket/src/server.rs @@ -0,0 +1,190 @@ +use crate::entities::{ClientMessage, Connect, Disconnect, WSUser}; +use crate::error::WSError; +use crate::ClientSink; + +use actix::{Actor, Context, Handler, ResponseFuture}; +use collab::core::collab::MutexCollab; +use collab::core::origin::CollabOrigin; +use collab_persistence::kv::rocks_kv::RocksCollabDB; +use collab_persistence::kv::KVStore; +use collab_plugins::disk_plugin::rocksdb_server::RocksdbServerDiskPlugin; +use collab_sync::server::{ + CollabBroadcast, CollabGroup, CollabIDGen, CollabId, NonZeroNodeId, COLLAB_ID_LEN, +}; +use dashmap::DashMap; +use parking_lot::{Mutex, RwLock}; +use std::collections::HashMap; + +use collab_persistence::keys::make_collab_id_key; +use collab_sync::msg::CollabMessage; + +use std::sync::Arc; +use tokio::sync::mpsc::Sender; +use tokio_stream::wrappers::ReceiverStream; + +#[derive(Clone)] +pub struct CollabServer { + db: Arc, + /// Generate collab_id for new collab object + collab_id_gen: Arc>, + /// Memory cache for fast lookup of collab_id from object_id + collab_id_by_object_id: Arc>, + collab_groups: Arc>>, + client_streams: Arc, ClientStream>>>, +} + +impl CollabServer { + pub fn new(db: Arc) -> Result { + let collab_id_gen = Arc::new(Mutex::new(CollabIDGen::new(NonZeroNodeId(1)))); + let collab_id_by_object_id = Arc::new(DashMap::new()); + Ok(Self { + db, + collab_id_gen, + collab_id_by_object_id, + collab_groups: Default::default(), + client_streams: Default::default(), + }) + } + + fn create_collab_id(&self, object_id: &str) -> Result { + let collab_id = self.collab_id_gen.lock().next_id(); + let collab_key = make_collab_id_key(object_id.as_ref()); + self.db.with_write_txn(|w_txn| { + w_txn.insert(collab_key.as_ref(), collab_id.to_be_bytes())?; + Ok(()) + })?; + Ok(collab_id) + } + + fn get_collab_id(&self, object_id: &str) -> Option { + let collab_key = make_collab_id_key(object_id.as_ref()); + let read_txn = self.db.read_txn(); + let value = read_txn.get(collab_key.as_ref()).ok()??; + + let mut bytes = [0; COLLAB_ID_LEN]; + bytes[0..COLLAB_ID_LEN].copy_from_slice(value.as_ref()); + Some(CollabId::from_be_bytes(bytes)) + } + + fn get_or_create_collab_id(&self, object_id: &str) -> Result { + let collab_id = self.get_collab_id(object_id); + if let Some(collab_id) = collab_id { + self.create_group_if_need(collab_id, object_id); + Ok(collab_id) + } else { + let collab_id = self.create_collab_id(object_id)?; + self + .collab_id_by_object_id + .insert(object_id.to_string(), collab_id); + self.create_group_if_need(collab_id, object_id); + Ok(collab_id) + } + } + + fn create_group_if_need(&self, collab_id: CollabId, object_id: &str) { + if self.collab_groups.read().contains_key(&collab_id) { + return; + } + + let collab = MutexCollab::new(CollabOrigin::Empty, object_id, vec![]); + let plugin = RocksdbServerDiskPlugin::new(collab_id, self.db.clone()).unwrap(); + collab.lock().add_plugin(Arc::new(plugin)); + collab.initial(); + + let broadcast = CollabBroadcast::new(object_id, collab.clone(), 10); + let group = CollabGroup { + collab, + broadcast, + subscribers: Default::default(), + }; + self.collab_groups.write().insert(collab_id, group); + } +} + +impl Actor for CollabServer { + type Context = Context; +} + +impl Handler for CollabServer { + type Result = Result<(), WSError>; + + fn handle(&mut self, msg: Connect, _ctx: &mut Context) -> Self::Result { + let (stream_tx, rx) = tokio::sync::mpsc::channel(100); + let stream = ClientStream::new(ClientSink(msg.socket), ReceiverStream::new(rx), stream_tx); + self.client_streams.write().insert(msg.user, stream); + Ok(()) + } +} + +impl Handler for CollabServer { + type Result = Result<(), WSError>; + fn handle(&mut self, msg: Disconnect, _: &mut Context) -> Self::Result { + self.client_streams.write().remove(&msg.user); + Ok(()) + } +} + +impl Handler for CollabServer { + type Result = ResponseFuture<()>; + + fn handle(&mut self, msg: ClientMessage, _ctx: &mut Context) -> Self::Result { + let object_id = msg.collab_msg.object_id(); + if let Ok(collab_id) = self.get_or_create_collab_id(object_id) { + if let Some(collab_group) = self.collab_groups.write().get_mut(&collab_id) { + if let Some(stream) = self.client_streams.write().get_mut(&msg.user) { + if let Some((sink, stream)) = stream.split() { + let origin = match msg.collab_msg.origin() { + None => CollabOrigin::Empty, + Some(client) => client.clone(), + }; + let sub = collab_group + .broadcast + .subscribe(origin.clone(), sink, stream); + collab_group.subscribers.insert(origin, sub); + } + } + } + + let client_streams = self.client_streams.clone(); + Box::pin(async move { + if let Some(client_stream) = client_streams.read().get(&msg.user) { + let _ = client_stream.stream_tx.send(Ok(msg.collab_msg)).await; + } + }) + } else { + Box::pin(async move {}) + } + } +} + +impl actix::Supervised for CollabServer { + fn restarting(&mut self, _ctx: &mut Context) { + tracing::warn!("restarting"); + } +} + +pub struct ClientStream { + sink: Option, + stream: Option>>, + stream_tx: Sender>, +} + +impl ClientStream { + pub fn new( + sink: ClientSink, + stream: ReceiverStream>, + stream_tx: Sender>, + ) -> Self { + Self { + sink: Some(sink), + stream: Some(stream), + stream_tx, + } + } + + pub fn split(&mut self) -> Option<(ClientSink, ReceiverStream>)> { + let sink = self.sink.take()?; + let stream = self.stream.take()?; + Some((sink, stream)) + } +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..5cb0d67e --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,12 @@ +# https://rust-lang.github.io/rustfmt/?version=master&search= +max_width = 100 +tab_spaces = 2 +newline_style = "Auto" +match_block_trailing_comma = true +use_field_init_shorthand = true +use_try_shorthand = true +reorder_imports = true +reorder_modules = true +remove_nested_parens = true +merge_derives = true +edition = "2021" \ No newline at end of file diff --git a/src/api/user.rs b/src/api/user.rs index ebf88191..b11bf174 100644 --- a/src/api/user.rs +++ b/src/api/user.rs @@ -1,6 +1,6 @@ use crate::component::auth::{ - change_password, logged_user_from_request, login, logout, register, ChangePasswordRequest, - InputParamsError, LoginRequest, RegisterRequest, + change_password, logged_user_from_request, login, logout, register, ChangePasswordRequest, + InputParamsError, LoginRequest, RegisterRequest, }; use crate::component::token_state::SessionToken; use crate::domain::{UserEmail, UserName, UserPassword}; @@ -11,83 +11,83 @@ use actix_web::{web, HttpResponse, Scope}; use actix_web::{HttpRequest, Result}; pub fn user_scope() -> Scope { - web::scope("/api/user") - .service(web::resource("/login").route(web::post().to(login_handler))) - .service(web::resource("/logout").route(web::get().to(logout_handler))) - .service(web::resource("/register").route(web::post().to(register_handler))) - .service(web::resource("/password").route(web::post().to(change_password_handler))) + web::scope("/api/user") + .service(web::resource("/login").route(web::post().to(login_handler))) + .service(web::resource("/logout").route(web::get().to(logout_handler))) + .service(web::resource("/register").route(web::post().to(register_handler))) + .service(web::resource("/password").route(web::post().to(change_password_handler))) } async fn login_handler( - req: Json, - state: Data, - session: SessionToken, + req: Json, + state: Data, + session: SessionToken, ) -> Result { - let req = req.into_inner(); - let email = UserEmail::parse(req.email) - .map_err(InputParamsError::InvalidEmail)? - .0; - let password = UserPassword::parse(req.password) - .map_err(|_| InputParamsError::InvalidPassword)? - .0; - let (resp, token) = login(email, password, &state).await?; + let req = req.into_inner(); + let email = UserEmail::parse(req.email) + .map_err(InputParamsError::InvalidEmail)? + .0; + let password = UserPassword::parse(req.password) + .map_err(|_| InputParamsError::InvalidPassword)? + .0; + let (resp, token) = login(email, password, &state).await?; - // Renews the session key, assigning existing session state to new key. - session.renew(); - if let Err(err) = session.insert_token(token) { - // It needs to navigate to login page in web application - tracing::error!("Insert session failed: {:?}", err); - } + // Renews the session key, assigning existing session state to new key. + session.renew(); + if let Err(err) = session.insert_token(token) { + // It needs to navigate to login page in web application + tracing::error!("Insert session failed: {:?}", err); + } - Ok(HttpResponse::Ok().json(resp)) + Ok(HttpResponse::Ok().json(resp)) } async fn logout_handler(req: HttpRequest, state: Data) -> Result { - let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?; - logout(logged_user, state.user.clone()).await; - Ok(HttpResponse::Ok().finish()) + let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?; + logout(logged_user, state.user.clone()).await; + Ok(HttpResponse::Ok().finish()) } #[tracing::instrument(level = "debug", skip(state))] async fn register_handler(req: Json, state: Data) -> Result { - let req = req.into_inner(); - let name = UserName::parse(req.name) - .map_err(InputParamsError::InvalidName)? - .0; - let email = UserEmail::parse(req.email) - .map_err(InputParamsError::InvalidEmail)? - .0; - let password = UserPassword::parse(req.password) - .map_err(|_| InputParamsError::InvalidPassword)? - .0; + let req = req.into_inner(); + let name = UserName::parse(req.name) + .map_err(InputParamsError::InvalidName)? + .0; + let email = UserEmail::parse(req.email) + .map_err(InputParamsError::InvalidEmail)? + .0; + let password = UserPassword::parse(req.password) + .map_err(|_| InputParamsError::InvalidPassword)? + .0; - let resp = register(name, email, password, &state).await?; - Ok(HttpResponse::Ok().json(resp)) + let resp = register(name, email, password, &state).await?; + Ok(HttpResponse::Ok().json(resp)) } async fn change_password_handler( - req: HttpRequest, - payload: Json, - // session: SessionToken, - state: Data, + req: HttpRequest, + payload: Json, + // session: SessionToken, + state: Data, ) -> Result { - let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?; - let payload = payload.into_inner(); - if payload.new_password != payload.new_password_confirm { - return Err(InputParamsError::PasswordNotMatch.into()); - } + let logged_user = logged_user_from_request(&req, &state.config.application.server_key)?; + let payload = payload.into_inner(); + if payload.new_password != payload.new_password_confirm { + return Err(InputParamsError::PasswordNotMatch.into()); + } - let new_password = UserPassword::parse(payload.new_password) - .map_err(|_| InputParamsError::InvalidPassword)? - .0; + let new_password = UserPassword::parse(payload.new_password) + .map_err(|_| InputParamsError::InvalidPassword)? + .0; - change_password( - state.pg_pool.clone(), - logged_user.clone(), - payload.current_password, - new_password, - ) - .await?; + change_password( + state.pg_pool.clone(), + logged_user.clone(), + payload.current_password, + new_password, + ) + .await?; - Ok(HttpResponse::Ok().finish()) + Ok(HttpResponse::Ok().finish()) } diff --git a/src/api/ws.rs b/src/api/ws.rs index de8a0378..ed350e2b 100644 --- a/src/api/ws.rs +++ b/src/api/ws.rs @@ -1,32 +1,40 @@ use crate::component::auth::LoggedUser; -use crate::component::ws::{MessageReceivers, WSClient, WSServer}; use crate::state::State; use actix::Addr; use actix_web::web::{Data, Path, Payload}; use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope}; use actix_web_actors::ws; +use secrecy::Secret; +use websocket::entities::WSUser; +use websocket::{CollabServer, CollabSession}; pub fn ws_scope() -> Scope { - web::scope("/ws").service(establish_ws_connection) + web::scope("/ws").service(establish_ws_connection) } #[get("/{token}")] pub async fn establish_ws_connection( - request: HttpRequest, - payload: Payload, - token: Path, - state: Data, - server: Data>, - msg_receivers: Data, + request: HttpRequest, + payload: Payload, + token: Path, + state: Data, + server: Data>, ) -> Result { - tracing::info!("establish_ws_connection"); - let user = LoggedUser::from_token(&state.config.application.server_key, token.as_str())?; - let client = WSClient::new(user, server.get_ref().clone(), msg_receivers); - match ws::start(client, &request, payload) { - Ok(response) => Ok(response), - Err(e) => { - tracing::error!("ws connection error: {:?}", e); - Err(e) - } - } + let user = LoggedUser::from_token(&state.config.application.server_key, token.as_str())?; + let client = CollabSession::new(user.into(), server.get_ref().clone()); + match ws::start(client, &request, payload) { + Ok(response) => Ok(response), + Err(e) => { + tracing::error!("ws connection error: {:?}", e); + Err(e) + }, + } +} + +impl From for WSUser { + fn from(user: LoggedUser) -> Self { + Self { + user_id: Secret::new(user.expose_secret().to_string()), + } + } } diff --git a/src/application.rs b/src/application.rs index f5349d3e..41f0f6d3 100644 --- a/src/application.rs +++ b/src/application.rs @@ -10,6 +10,9 @@ use actix_session::SessionMiddleware; use actix_web::cookie::Key; use actix_web::{dev::Server, web, web::Data, App, HttpServer}; +use actix::Actor; + +use collab_persistence::kv::rocks_kv::RocksCollabDB; use openssl::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod}; use openssl::x509::X509; use secrecy::{ExposeSecret, Secret}; @@ -18,121 +21,136 @@ use sqlx::{postgres::PgPoolOptions, PgPool}; use std::net::TcpListener; use std::sync::Arc; use tokio::sync::RwLock; + use tracing_actix_web::TracingLogger; +use websocket::CollabServer; pub struct Application { - port: u16, - server: Server, + port: u16, + server: Server, } impl Application { - pub async fn build(config: Config, state: State) -> Result { - let address = format!("{}:{}", config.application.host, config.application.port); - let listener = TcpListener::bind(&address)?; - let port = listener.local_addr().unwrap().port(); - let server = run(listener, state, config).await?; + pub async fn build(config: Config, state: State) -> Result { + let address = format!("{}:{}", config.application.host, config.application.port); + let listener = TcpListener::bind(&address)?; + let port = listener.local_addr().unwrap().port(); + let server = run(listener, state, config).await?; - Ok(Self { port, server }) - } + Ok(Self { port, server }) + } - pub async fn run_until_stopped(self) -> Result<(), std::io::Error> { - self.server.await - } + pub async fn run_until_stopped(self) -> Result<(), std::io::Error> { + self.server.await + } - pub fn port(&self) -> u16 { - self.port - } + pub fn port(&self) -> u16 { + self.port + } } pub async fn run( - listener: TcpListener, - state: State, - config: Config, + listener: TcpListener, + state: State, + config: Config, ) -> Result { - let redis_store = RedisSessionStore::new(config.redis_uri.expose_secret()) - .await - .map_err(|e| { - anyhow::anyhow!( - "Failed to connect to Redis at {:?}: {:?}", - config.redis_uri, - e - ) - })?; - let pair = get_certificate_and_server_key(&config); - let key = pair - .as_ref() - .map(|(_, server_key)| Key::from(server_key.expose_secret().as_bytes())) - .unwrap_or_else(Key::generate); - let mut server = HttpServer::new(move || { - App::new() - // Session middleware + let redis_store = RedisSessionStore::new(config.redis_uri.expose_secret()) + .await + .map_err(|e| { + anyhow::anyhow!( + "Failed to connect to Redis at {:?}: {:?}", + config.redis_uri, + e + ) + })?; + let pair = get_certificate_and_server_key(&config); + let key = pair + .as_ref() + .map(|(_, server_key)| Key::from(server_key.expose_secret().as_bytes())) + .unwrap_or_else(Key::generate); + + let collab_server = CollabServer::new(state.rocksdb.clone()).unwrap().start(); + let mut server = HttpServer::new(move || { + App::new() .wrap( SessionMiddleware::builder(redis_store.clone(), key.clone()) .cookie_name(HEADER_TOKEN.to_string()) .build(), ) + // .wrap(ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, add_error_header)) .wrap(IdentityMiddleware::default()) .wrap(default_cors()) .wrap(TracingLogger::default()) .app_data(web::JsonConfig::default().limit(4096)) .service(user_scope()) .service(ws_scope()) + .app_data(Data::new(collab_server.clone())) .app_data(Data::new(state.clone())) - }); + }); - server = match pair { - None => server.listen(listener)?, - Some((certificate, _)) => { - server.listen_openssl(listener, make_ssl_acceptor_builder(certificate))? - } - }; + server = match pair { + None => server.listen(listener)?, + Some((certificate, _)) => { + server.listen_openssl(listener, make_ssl_acceptor_builder(certificate))? + }, + }; - Ok(server.run()) + Ok(server.run()) } fn get_certificate_and_server_key(config: &Config) -> Option<(Secret, Secret)> { - let tls_config = config.application.tls_config.as_ref()?; - match tls_config { - TlsConfig::NoTls => None, - TlsConfig::SelfSigned => Some(create_self_signed_certificate().unwrap()), - } + let tls_config = config.application.tls_config.as_ref()?; + match tls_config { + TlsConfig::NoTls => None, + TlsConfig::SelfSigned => Some(create_self_signed_certificate().unwrap()), + } } pub async fn init_state(config: &Config) -> State { - let pg_pool = get_connection_pool(&config.database) - .await - .unwrap_or_else(|_| panic!("Failed to connect to Postgres at {:?}.", config.database)); + let pg_pool = get_connection_pool(&config.database) + .await + .unwrap_or_else(|_| panic!("Failed to connect to Postgres at {:?}.", config.database)); - State { - pg_pool, - config: Arc::new(config.clone()), - user: Arc::new(Default::default()), - id_gen: Arc::new(RwLock::new(Snowflake::new(1))), - } + std::fs::create_dir_all(config.application.rocksdb_db_dir()).expect("create rocksdb db dir"); + let rocksdb = Arc::new(RocksCollabDB::open(config.application.rocksdb_db_dir()).unwrap()); + State { + pg_pool, + rocksdb, + config: Arc::new(config.clone()), + user: Arc::new(Default::default()), + id_gen: Arc::new(RwLock::new(Snowflake::new(1))), + } } pub async fn get_connection_pool(setting: &DatabaseSetting) -> Result { - PgPoolOptions::new() - .acquire_timeout(std::time::Duration::from_secs(5)) - .connect_with(setting.with_db()) - .await + PgPoolOptions::new() + .acquire_timeout(std::time::Duration::from_secs(5)) + .connect_with(setting.with_db()) + .await } fn make_ssl_acceptor_builder(certificate: Secret) -> SslAcceptorBuilder { - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - let x509_cert = X509::from_pem(certificate.expose_secret().as_bytes()).unwrap(); - builder.set_certificate(&x509_cert).unwrap(); - builder - .set_private_key_file("./cert/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("./cert/cert.pem") - .unwrap(); - builder - .set_min_proto_version(Some(openssl::ssl::SslVersion::TLS1_2)) - .unwrap(); - builder - .set_max_proto_version(Some(openssl::ssl::SslVersion::TLS1_3)) - .unwrap(); - builder + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + let x509_cert = X509::from_pem(certificate.expose_secret().as_bytes()).unwrap(); + builder.set_certificate(&x509_cert).unwrap(); + builder + .set_private_key_file("./cert/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("./cert/cert.pem") + .unwrap(); + builder + .set_min_proto_version(Some(openssl::ssl::SslVersion::TLS1_2)) + .unwrap(); + builder + .set_max_proto_version(Some(openssl::ssl::SslVersion::TLS1_3)) + .unwrap(); + builder } + +// fn add_error_header( +// res: dev::ServiceResponse, +// ) -> Result, actix_web::Error> { +// tracing::error!("{:?}", res.request()); +// Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) +// } diff --git a/src/component/auth/error.rs b/src/component/auth/error.rs index 39ad47a0..ed96c678 100644 --- a/src/component/auth/error.rs +++ b/src/component/auth/error.rs @@ -5,76 +5,76 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum AuthError { - #[error("Credentials is invalid")] - InvalidCredentials(#[source] anyhow::Error), + #[error("Credentials is invalid")] + InvalidCredentials(#[source] anyhow::Error), - #[error("User is not exist")] - UserNotExist(#[source] anyhow::Error), + #[error("User is not exist")] + UserNotExist(#[source] anyhow::Error), - #[error("{} is already used", email)] - UserAlreadyExist { email: String }, + #[error("{} is already used", email)] + UserAlreadyExist { email: String }, - #[error("Invalid password")] - InvalidPassword, + #[error("Invalid password")] + InvalidPassword, - #[error("User is unauthorized")] - Unauthorized, + #[error("User is unauthorized")] + Unauthorized, - #[error("User internal error")] - InternalError(#[from] anyhow::Error), + #[error("User internal error")] + InternalError(#[from] anyhow::Error), - #[error("Parser uuid failed: {}", err)] - InvalidUuid { err: String }, + #[error("Parser uuid failed: {}", err)] + InvalidUuid { err: String }, } pub fn internal_error(error: anyhow::Error) -> AuthError { - AuthError::InternalError(error) + AuthError::InternalError(error) } impl actix_web::error::ResponseError for AuthError { - fn status_code(&self) -> StatusCode { - match *self { - AuthError::InvalidCredentials(_) => StatusCode::UNAUTHORIZED, - AuthError::UserNotExist(_) => StatusCode::UNAUTHORIZED, - AuthError::UserAlreadyExist { .. } => StatusCode::BAD_REQUEST, - AuthError::InvalidPassword => StatusCode::UNAUTHORIZED, - AuthError::Unauthorized => StatusCode::UNAUTHORIZED, - AuthError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, - AuthError::InvalidUuid { .. } => StatusCode::UNAUTHORIZED, - } + fn status_code(&self) -> StatusCode { + match *self { + AuthError::InvalidCredentials(_) => StatusCode::UNAUTHORIZED, + AuthError::UserNotExist(_) => StatusCode::UNAUTHORIZED, + AuthError::UserAlreadyExist { .. } => StatusCode::BAD_REQUEST, + AuthError::InvalidPassword => StatusCode::UNAUTHORIZED, + AuthError::Unauthorized => StatusCode::UNAUTHORIZED, + AuthError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, + AuthError::InvalidUuid { .. } => StatusCode::UNAUTHORIZED, } + } - fn error_response(&self) -> HttpResponse { - HttpResponse::build(self.status_code()).body(self.to_string()) - } + fn error_response(&self) -> HttpResponse { + HttpResponse::build(self.status_code()).body(self.to_string()) + } } #[derive(Debug, Error)] pub enum InputParamsError { - #[error("Invalid name")] - InvalidName(String), + #[error("Invalid name")] + InvalidName(String), - #[error("Invalid email format")] - InvalidEmail(String), + #[error("Invalid email format")] + InvalidEmail(String), - #[error("Invalid password")] - InvalidPassword, + #[error("Invalid password")] + InvalidPassword, - #[error("You entered two different new passwords")] - PasswordNotMatch, + #[error("You entered two different new passwords")] + PasswordNotMatch, } impl actix_web::error::ResponseError for InputParamsError { - fn status_code(&self) -> StatusCode { - match *self { - InputParamsError::InvalidName(_) => StatusCode::BAD_REQUEST, - InputParamsError::InvalidEmail(_) => StatusCode::BAD_REQUEST, - InputParamsError::InvalidPassword => StatusCode::BAD_REQUEST, - InputParamsError::PasswordNotMatch => StatusCode::BAD_REQUEST, - } + fn status_code(&self) -> StatusCode { + match *self { + InputParamsError::InvalidName(_) => StatusCode::BAD_REQUEST, + InputParamsError::InvalidEmail(_) => StatusCode::BAD_REQUEST, + InputParamsError::InvalidPassword => StatusCode::BAD_REQUEST, + InputParamsError::PasswordNotMatch => StatusCode::BAD_REQUEST, } + } - fn error_response(&self) -> HttpResponse { - HttpResponse::build(self.status_code()).body(self.to_string()) - } + fn error_response(&self) -> HttpResponse { + HttpResponse::build(self.status_code()).body(self.to_string()) + } } diff --git a/src/component/auth/password.rs b/src/component/auth/password.rs index 39878664..2d829d54 100644 --- a/src/component/auth/password.rs +++ b/src/component/auth/password.rs @@ -8,84 +8,85 @@ use secrecy::{ExposeSecret, Secret}; use sqlx::PgPool; pub struct Credentials { - pub email: String, - pub password: Secret, + pub email: String, + pub password: Secret, } #[tracing::instrument(level = "debug", skip(credentials, pool))] pub async fn validate_credentials( - credentials: Credentials, - pool: &PgPool, + credentials: Credentials, + pool: &PgPool, ) -> Result { - let mut uid = None; - let mut expected_hash_password = Secret::new( - "$argon2id$v=19$m=15000,t=2,p=1$\ + let mut uid = None; + let mut expected_hash_password = Secret::new( + "$argon2id$v=19$m=15000,t=2,p=1$\ gZiV/M1gPc22ElAH/Jh1Hw$\ CWOrkoo7oJBQ/iyh7uJ0LO2aLEfrHwTWllSAxT0zRno" - .to_string(), - ); + .to_string(), + ); - if let Some((stored_uid, stored_hash_password)) = - get_stored_credentials(&credentials.email, pool).await? - { - uid = Some(stored_uid); - expected_hash_password = stored_hash_password; - } + if let Some((stored_uid, stored_hash_password)) = + get_stored_credentials(&credentials.email, pool).await? + { + uid = Some(stored_uid); + expected_hash_password = stored_hash_password; + } - spawn_blocking_with_tracing(move || { - verify_password_hash(expected_hash_password, credentials.password) - }) - .await - .context("Failed to spawn blocking task.")??; + spawn_blocking_with_tracing(move || { + verify_password_hash(expected_hash_password, credentials.password) + }) + .await + .context("Failed to spawn blocking task.")??; - uid.ok_or_else(|| anyhow::anyhow!("Unknown email.")) - .map_err(AuthError::InvalidCredentials) + uid + .ok_or_else(|| anyhow::anyhow!("Unknown email.")) + .map_err(AuthError::InvalidCredentials) } pub fn compute_hash_password(password: &[u8]) -> Result, anyhow::Error> { - let salt = SaltString::generate(&mut rand::thread_rng()); - let password = Argon2::new( - Algorithm::Argon2id, - Version::V0x13, - Params::new(15000, 2, 1, None).unwrap(), - ) - .hash_password(password, &salt)? - .to_string(); - Ok(Secret::new(password)) + let salt = SaltString::generate(&mut rand::thread_rng()); + let password = Argon2::new( + Algorithm::Argon2id, + Version::V0x13, + Params::new(15000, 2, 1, None).unwrap(), + ) + .hash_password(password, &salt)? + .to_string(); + Ok(Secret::new(password)) } #[tracing::instrument(level = "debug", skip(email, pool))] async fn get_stored_credentials( - email: &str, - pool: &PgPool, + email: &str, + pool: &PgPool, ) -> Result)>, anyhow::Error> { - let row = sqlx::query!( - r#" + let row = sqlx::query!( + r#" SELECT uid, password FROM users WHERE email = $1 "#, - email, - ) - .fetch_optional(pool) - .await - .context("Failed to performed a query to retrieve stored credentials.")? - .map(|row| (row.uid, Secret::new(row.password))); - Ok(row) + email, + ) + .fetch_optional(pool) + .await + .context("Failed to performed a query to retrieve stored credentials.")? + .map(|row| (row.uid, Secret::new(row.password))); + Ok(row) } fn verify_password_hash( - expected_password_hash: Secret, - password_candidate: Secret, + expected_password_hash: Secret, + password_candidate: Secret, ) -> Result<(), AuthError> { - let expected_hash_password = PasswordHash::new(expected_password_hash.expose_secret()) - .context("Failed to parse hash in PHC string format.")?; + let expected_hash_password = PasswordHash::new(expected_password_hash.expose_secret()) + .context("Failed to parse hash in PHC string format.")?; - Argon2::default() - .verify_password( - password_candidate.expose_secret().as_bytes(), - &expected_hash_password, - ) - .context("Invalid password.") - .map_err(|_| AuthError::InvalidPassword) + Argon2::default() + .verify_password( + password_candidate.expose_secret().as_bytes(), + &expected_hash_password, + ) + .context("Invalid password.") + .map_err(|_| AuthError::InvalidPassword) } diff --git a/src/component/auth/user.rs b/src/component/auth/user.rs index bac06560..d0a68f4e 100644 --- a/src/component/auth/user.rs +++ b/src/component/auth/user.rs @@ -1,5 +1,5 @@ use crate::component::auth::{ - compute_hash_password, internal_error, validate_credentials, AuthError, Credentials, + compute_hash_password, internal_error, validate_credentials, AuthError, Credentials, }; use crate::config::env::domain; use crate::state::{State, UserCache}; @@ -18,234 +18,234 @@ use token::{create_token, parse_token, TokenError}; use tokio::sync::RwLock; pub async fn login( - email: String, - password: String, - state: &State, + email: String, + password: String, + state: &State, ) -> Result<(LoginResponse, Secret), AuthError> { - let credentials = Credentials { - email, - password: Secret::new(password), - }; - let server_key = &state.config.application.server_key; + let credentials = Credentials { + email, + password: Secret::new(password), + }; + let server_key = &state.config.application.server_key; - match validate_credentials(credentials, &state.pg_pool).await { - Ok(uid) => { - let token = Token::create_token(uid, server_key)?; - let logged_user = LoggedUser::new(uid); - state.user.write().await.authorized(logged_user); - Ok(( - LoginResponse { - token: token.0.clone(), - uid: uid.to_string(), - }, - Secret::new(token), - )) - } - Err(err) => Err(err), - } + match validate_credentials(credentials, &state.pg_pool).await { + Ok(uid) => { + let token = Token::create_token(uid, server_key)?; + let logged_user = LoggedUser::new(uid); + state.user.write().await.authorized(logged_user); + Ok(( + LoginResponse { + token: token.0.clone(), + uid: uid.to_string(), + }, + Secret::new(token), + )) + }, + Err(err) => Err(err), + } } pub async fn logout(logged_user: LoggedUser, cache: Arc>) { - cache.write().await.unauthorized(logged_user); + cache.write().await.unauthorized(logged_user); } pub async fn register( - username: String, - email: String, - password: String, - state: &State, + username: String, + email: String, + password: String, + state: &State, ) -> Result { - let pg_pool = state.pg_pool.clone(); - let server_key = &state.config.application.server_key; - let mut transaction = pg_pool - .begin() - .await - .context("Failed to acquire a Postgres connection to register user") - .map_err(internal_error)?; + let pg_pool = state.pg_pool.clone(); + let server_key = &state.config.application.server_key; + let mut transaction = pg_pool + .begin() + .await + .context("Failed to acquire a Postgres connection to register user") + .map_err(internal_error)?; - if is_email_exist(&mut transaction, email.as_ref()) - .await - .map_err(internal_error)? - { - return Err(AuthError::UserAlreadyExist { email }); - } + if is_email_exist(&mut transaction, email.as_ref()) + .await + .map_err(internal_error)? + { + return Err(AuthError::UserAlreadyExist { email }); + } - let uid = state.id_gen.write().await.next_id(); - let token = Token::create_token(uid, server_key)?; - let password = compute_hash_password(password.as_bytes()).map_err(internal_error)?; - let _ = sqlx::query!( - r#" + let uid = state.id_gen.write().await.next_id(); + let token = Token::create_token(uid, server_key)?; + let password = compute_hash_password(password.as_bytes()).map_err(internal_error)?; + let _ = sqlx::query!( + r#" INSERT INTO users (uid, email, username, create_time, password) VALUES ($1, $2, $3, $4, $5) "#, - uid, - email, - username, - Utc::now(), - password.expose_secret(), - ) - .execute(&mut transaction) + uid, + email, + username, + Utc::now(), + password.expose_secret(), + ) + .execute(&mut transaction) + .await + .context("Save user to disk failed") + .map_err(internal_error)?; + + transaction + .commit() .await - .context("Save user to disk failed") + .context("Failed to commit SQL transaction to register user.") .map_err(internal_error)?; - transaction - .commit() - .await - .context("Failed to commit SQL transaction to register user.") - .map_err(internal_error)?; + let logged_user = LoggedUser::new(uid); + state.user.write().await.authorized(logged_user); - let logged_user = LoggedUser::new(uid); - state.user.write().await.authorized(logged_user); - - Ok(RegisterResponse { - token: token.0.clone(), - }) + Ok(RegisterResponse { + token: token.0.clone(), + }) } pub async fn change_password( - pg_pool: PgPool, - logged_user: LoggedUser, - current_password: String, - new_password: String, + pg_pool: PgPool, + logged_user: LoggedUser, + current_password: String, + new_password: String, ) -> Result<(), AuthError> { - let mut transaction = pg_pool - .begin() - .await - .context("Failed to acquire a Postgres connection to change password") - .map_err(internal_error)?; + let mut transaction = pg_pool + .begin() + .await + .context("Failed to acquire a Postgres connection to change password") + .map_err(internal_error)?; - let email = get_user_email(*logged_user.expose_secret(), &mut transaction).await?; + let email = get_user_email(*logged_user.expose_secret(), &mut transaction).await?; - // check password - let credentials = Credentials { - email, - password: Secret::new(current_password), - }; - let _ = validate_credentials(credentials, &pg_pool).await?; + // check password + let credentials = Credentials { + email, + password: Secret::new(current_password), + }; + let _ = validate_credentials(credentials, &pg_pool).await?; - // Hash password - let new_hash_password = - spawn_blocking_with_tracing(move || compute_hash_password(new_password.as_bytes())) - .await - .context("Failed to hash password")??; + // Hash password + let new_hash_password = + spawn_blocking_with_tracing(move || compute_hash_password(new_password.as_bytes())) + .await + .context("Failed to hash password")??; - // Save password to disk - let sql = "UPDATE users SET password = $1 where uid = $2"; - let _ = sqlx::query(sql) - .bind(new_hash_password.expose_secret()) - .bind(logged_user.expose_secret()) - .execute(&mut transaction) - .await - .context("Failed to change user's password in the database.")?; + // Save password to disk + let sql = "UPDATE users SET password = $1 where uid = $2"; + let _ = sqlx::query(sql) + .bind(new_hash_password.expose_secret()) + .bind(logged_user.expose_secret()) + .execute(&mut transaction) + .await + .context("Failed to change user's password in the database.")?; - transaction - .commit() - .await - .context("Failed to commit SQL transaction to change user's password") - .map_err(internal_error)?; - Ok(()) + transaction + .commit() + .await + .context("Failed to commit SQL transaction to change user's password") + .map_err(internal_error)?; + Ok(()) } pub async fn get_user_email( - uid: i64, - transaction: &mut Transaction<'_, Postgres>, + uid: i64, + transaction: &mut Transaction<'_, Postgres>, ) -> Result { - let row = sqlx::query!( - r#" + let row = sqlx::query!( + r#" SELECT email FROM users WHERE uid = $1 "#, - uid, - ) - .fetch_one(transaction) - .await - .context("Failed to retrieve the username`")?; - Ok(row.email) + uid, + ) + .fetch_one(transaction) + .await + .context("Failed to retrieve the username`")?; + Ok(row.email) } /// TODO: cache this state in [State] async fn is_email_exist( - transaction: &mut Transaction<'_, Postgres>, - email: &str, + transaction: &mut Transaction<'_, Postgres>, + email: &str, ) -> Result { - let result = sqlx::query(r#"SELECT email FROM users WHERE email = $1"#) - .bind(email) - .fetch_optional(transaction) - .await?; + let result = sqlx::query(r#"SELECT email FROM users WHERE email = $1"#) + .bind(email) + .fetch_optional(transaction) + .await?; - Ok(result.is_some()) + Ok(result.is_some()) } #[derive(Default, Deserialize, Debug)] pub struct LoginRequest { - pub email: String, - pub password: String, + pub email: String, + pub password: String, } #[derive(Default, Serialize, Deserialize, Debug)] pub struct LoginResponse { - pub token: String, - pub uid: String, + pub token: String, + pub uid: String, } #[derive(Default, Deserialize, Debug)] pub struct RegisterRequest { - pub email: String, - pub password: String, - pub name: String, + pub email: String, + pub password: String, + pub name: String, } #[derive(Default, Serialize, Deserialize, Debug)] pub struct RegisterResponse { - pub token: String, + pub token: String, } #[derive(Default, Deserialize, Debug)] pub struct ChangePasswordRequest { - pub current_password: String, - pub new_password: String, - pub new_password_confirm: String, + pub current_password: String, + pub new_password: String, + pub new_password_confirm: String, } #[derive(Clone, Default)] -pub struct WrapI64(i64); -impl Copy for WrapI64 {} -impl DefaultIsZeroes for WrapI64 {} -impl DebugSecret for WrapI64 {} -impl CloneableSecret for WrapI64 {} +pub struct SecretI64(i64); +impl Copy for SecretI64 {} +impl DefaultIsZeroes for SecretI64 {} +impl DebugSecret for SecretI64 {} +impl CloneableSecret for SecretI64 {} -impl std::ops::Deref for WrapI64 { - type Target = i64; +impl std::ops::Deref for SecretI64 { + type Target = i64; - fn deref(&self) -> &Self::Target { - &self.0 - } + fn deref(&self) -> &Self::Target { + &self.0 + } } #[derive(Debug, Clone)] -pub struct LoggedUser(Secret); +pub struct LoggedUser(Secret); impl From for LoggedUser { - fn from(c: Claim) -> Self { - Self(Secret::new(WrapI64(c.uid))) - } + fn from(c: Claim) -> Self { + Self(Secret::new(SecretI64(c.uid))) + } } impl LoggedUser { - pub fn new(uid: i64) -> Self { - Self(Secret::new(WrapI64(uid))) - } + pub fn new(uid: i64) -> Self { + Self(Secret::new(SecretI64(uid))) + } - pub fn from_token(server_key: &Secret, token: &str) -> Result { - let user: LoggedUser = Token::decode_token(server_key, token)?.into(); - Ok(user) - } + pub fn from_token(server_key: &Secret, token: &str) -> Result { + let user: LoggedUser = Token::decode_token(server_key, token)?.into(); + Ok(user) + } - pub fn expose_secret(&self) -> &i64 { - self.0.expose_secret() - } + pub fn expose_secret(&self) -> &i64 { + self.0.expose_secret() + } } pub const HEADER_TOKEN: &str = "token"; @@ -253,68 +253,68 @@ pub const EXPIRED_DURATION_DAYS: i64 = 30; #[derive(Debug, Serialize, Deserialize)] pub struct Claim { - iss: String, - uid: i64, + iss: String, + uid: i64, } impl Claim { - pub fn with_user_id(uid: i64) -> Self { - Self { iss: domain(), uid } - } + pub fn with_user_id(uid: i64) -> Self { + Self { iss: domain(), uid } + } } #[derive(Clone, Default, Serialize, Deserialize)] pub struct Token(pub String); impl Zeroize for Token { - fn zeroize(&mut self) { - self.0.zeroize() - } + fn zeroize(&mut self) { + self.0.zeroize() + } } impl Token { - pub fn create_token(uid: i64, server_key: &Secret) -> Result { - let claim = Claim::with_user_id(uid); - let token = create_token( - server_key.expose_secret().as_str(), - claim, - Duration::days(EXPIRED_DURATION_DAYS), - ) - .map_err(|e| match e { - TokenError::Jwt(_) => AuthError::Unauthorized, - TokenError::Expired => AuthError::Unauthorized, - })?; - Ok(Self(token)) - } + pub fn create_token(uid: i64, server_key: &Secret) -> Result { + let claim = Claim::with_user_id(uid); + let token = create_token( + server_key.expose_secret().as_str(), + claim, + Duration::days(EXPIRED_DURATION_DAYS), + ) + .map_err(|e| match e { + TokenError::Jwt(_) => AuthError::Unauthorized, + TokenError::Expired => AuthError::Unauthorized, + })?; + Ok(Self(token)) + } - pub fn decode_token(server_key: &Secret, token: &str) -> Result { - parse_token::(server_key.expose_secret().as_str(), token) - .map_err(|_| AuthError::Unauthorized) - } + pub fn decode_token(server_key: &Secret, token: &str) -> Result { + parse_token::(server_key.expose_secret().as_str(), token) + .map_err(|_| AuthError::Unauthorized) + } } pub fn logged_user_from_request( - request: &HttpRequest, - server_key: &Secret, + request: &HttpRequest, + server_key: &Secret, ) -> Result { - match request.headers().get(HEADER_TOKEN) { - None => Err(AuthError::Unauthorized), - Some(header) => match header.to_str() { - Ok(token_str) => LoggedUser::from_token(server_key, token_str), - Err(_) => Err(AuthError::Unauthorized), - }, - } + match request.headers().get(HEADER_TOKEN) { + None => Err(AuthError::Unauthorized), + Some(header) => match header.to_str() { + Ok(token_str) => LoggedUser::from_token(server_key, token_str), + Err(_) => Err(AuthError::Unauthorized), + }, + } } pub fn uid_from_request( - request: &HttpRequest, - server_key: &Secret, + request: &HttpRequest, + server_key: &Secret, ) -> Result, AuthError> { - match request.headers().get(HEADER_TOKEN) { - Some(header) => match header.to_str() { - Ok(val) => Token::decode_token(server_key, val).map(|claim| Secret::new(claim.uid)), - Err(_) => Err(AuthError::Unauthorized), - }, - None => Err(AuthError::Unauthorized), - } + match request.headers().get(HEADER_TOKEN) { + Some(header) => match header.to_str() { + Ok(val) => Token::decode_token(server_key, val).map(|claim| Secret::new(claim.uid)), + Err(_) => Err(AuthError::Unauthorized), + }, + None => Err(AuthError::Unauthorized), + } } diff --git a/src/component/mod.rs b/src/component/mod.rs index adfbdbb9..79bb8910 100644 --- a/src/component/mod.rs +++ b/src/component/mod.rs @@ -1,3 +1,2 @@ pub mod auth; pub mod token_state; -pub mod ws; diff --git a/src/component/token_state.rs b/src/component/token_state.rs index 4a50a1c7..b60506b0 100644 --- a/src/component/token_state.rs +++ b/src/component/token_state.rs @@ -9,30 +9,30 @@ use std::future::{ready, Ready}; pub struct SessionToken(Session); impl SessionToken { - const TOKEN_ID_KEY: &'static str = "token"; + const TOKEN_ID_KEY: &'static str = "token"; - pub fn renew(&self) { - self.0.renew(); - } + pub fn renew(&self) { + self.0.renew(); + } - pub fn insert_token(&self, token: Secret) -> Result<(), SessionInsertError> { - self.0.insert(Self::TOKEN_ID_KEY, token.expose_secret()) - } + pub fn insert_token(&self, token: Secret) -> Result<(), SessionInsertError> { + self.0.insert(Self::TOKEN_ID_KEY, token.expose_secret()) + } - pub fn get_token(&self) -> Result, SessionGetError> { - self.0.get(Self::TOKEN_ID_KEY) - } + pub fn get_token(&self) -> Result, SessionGetError> { + self.0.get(Self::TOKEN_ID_KEY) + } - pub fn log_out(self) { - self.0.purge() - } + pub fn log_out(self) { + self.0.purge() + } } impl FromRequest for SessionToken { - type Error = ::Error; - type Future = Ready>; + type Error = ::Error; + type Future = Ready>; - fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future { - ready(Ok(SessionToken(req.get_session()))) - } + fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future { + ready(Ok(SessionToken(req.get_session()))) + } } diff --git a/src/component/ws/client.rs b/src/component/ws/client.rs deleted file mode 100644 index cb4092d3..00000000 --- a/src/component/ws/client.rs +++ /dev/null @@ -1,164 +0,0 @@ -use crate::component::auth::LoggedUser; -use crate::component::ws::entities::{ - Connect, Disconnect, MessageDetail, MessagePayload, Socket, WebSocketMessage, -}; -use crate::component::ws::server::WSServer; -use crate::component::ws::{HEARTBEAT_INTERVAL, PING_TIMEOUT}; -use actix::*; -use actix_http::ws::Message::*; -use actix_web::web::Data; -use actix_web_actors::ws; -use bytes::Bytes; -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Instant; - -pub trait MessageReceiver: Send + Sync { - fn receive(&self, data: WSClientData); -} - -#[derive(Default)] -pub struct MessageReceivers { - inner: HashMap>, -} - -impl MessageReceivers { - pub fn new() -> Self { - MessageReceivers::default() - } - - pub fn insert(&mut self, channel: u8, receiver: Arc) { - self.inner.insert(channel, receiver); - } - - pub fn get(&self, source: u8) -> Option<&Arc> { - self.inner.get(&source) - } -} - -#[allow(dead_code)] -pub struct WSClientData { - pub(crate) socket: Socket, - pub(crate) detail: MessageDetail, -} - -pub struct WSClient { - user: Arc, - server: Addr, - msg_receivers: Data, - hb: Instant, -} - -impl WSClient { - pub fn new( - user: LoggedUser, - server: Addr, - msg_receivers: Data, - ) -> Self { - Self { - user: Arc::new(user), - server, - msg_receivers, - hb: Instant::now(), - } - } - - fn hb(&self, ctx: &mut ws::WebsocketContext) { - ctx.run_interval(HEARTBEAT_INTERVAL, |client, ctx| { - if Instant::now().duration_since(client.hb) > PING_TIMEOUT { - client.server.do_send(Disconnect { - user: client.user.clone(), - }); - ctx.stop(); - } else { - ctx.ping(b""); - } - }); - } - - fn handle_binary_message(&self, bytes: Bytes, socket: Socket) { - let MessagePayload { channel, detail } = MessagePayload::from_bytes(&bytes); - match self.msg_receivers.get(channel) { - None => { - tracing::error!("Can't find the receiver for {:?}", channel); - } - Some(handler) => { - let client_data = WSClientData { socket, detail }; - handler.receive(client_data); - } - } - } -} - -impl StreamHandler> for WSClient { - fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { - match msg { - Ok(Ping(msg)) => { - self.hb = Instant::now(); - ctx.pong(&msg); - } - Ok(Pong(_msg)) => { - // tracing::trace!("Receive {} pong {:?}", &self.session_id, &msg); - self.hb = Instant::now(); - } - Ok(Binary(bytes)) => { - let socket = ctx.address().recipient(); - self.handle_binary_message(bytes, socket); - } - Ok(Text(_)) => { - tracing::warn!("Receive unexpected text message"); - } - Ok(Close(reason)) => { - ctx.close(reason); - ctx.stop(); - } - Ok(ws::Message::Continuation(_)) => {} - Ok(ws::Message::Nop) => {} - Err(e) => { - tracing::error!("WebSocketStream protocol error {:?}", e); - ctx.stop(); - } - } - } -} - -impl Handler for WSClient { - type Result = (); - - fn handle(&mut self, msg: WebSocketMessage, ctx: &mut Self::Context) { - ctx.binary(msg.0); - } -} - -impl Actor for WSClient { - type Context = ws::WebsocketContext; - - fn started(&mut self, ctx: &mut Self::Context) { - self.hb(ctx); - let socket = ctx.address().recipient(); - let connect = Connect { - socket, - user: self.user.clone(), - }; - self.server - .send(connect) - .into_actor(self) - .then(|res, _client, _ctx| { - match res { - Ok(Ok(_)) => tracing::trace!("Send connect message to server success"), - Ok(Err(e)) => tracing::error!("Send connect message to server failed: {:?}", e), - Err(e) => tracing::error!("Send connect message to server failed: {:?}", e), - } - fut::ready(()) - }) - .wait(ctx); - } - - fn stopping(&mut self, _: &mut Self::Context) -> Running { - self.server.do_send(Disconnect { - user: self.user.clone(), - }); - - Running::Stop - } -} diff --git a/src/component/ws/entities.rs b/src/component/ws/entities.rs deleted file mode 100644 index 6915d672..00000000 --- a/src/component/ws/entities.rs +++ /dev/null @@ -1,92 +0,0 @@ -use crate::component::auth::LoggedUser; -use actix::{Message, Recipient}; -use bytes::Bytes; -use serde::{Deserialize, Serialize}; -use std::fmt::Formatter; -use std::sync::Arc; - -pub type Socket = Recipient; - -#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)] -pub struct WSSessionId(pub String); - -impl> std::convert::From for WSSessionId { - fn from(s: T) -> Self { - WSSessionId(s.as_ref().to_owned()) - } -} - -impl std::fmt::Display for WSSessionId { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let desc = &self.0.to_string(); - f.write_str(desc) - } -} - -pub struct Session { - pub user: Arc, - pub socket: Socket, -} - -impl std::convert::From for Session { - fn from(c: Connect) -> Self { - Self { - user: c.user, - socket: c.socket, - } - } -} - -#[derive(Debug, Message, Clone)] -#[rtype(result = "Result<(), WSError>")] -pub struct Connect { - pub socket: Socket, - pub user: Arc, -} - -#[derive(Debug, Message, Clone)] -#[rtype(result = "Result<(), WSError>")] -pub struct Disconnect { - pub user: Arc, -} - -#[derive(Debug, Message, Clone)] -#[rtype(result = "()")] -pub struct WebSocketMessage(pub Bytes); - -impl std::ops::Deref for WebSocketMessage { - type Target = Bytes; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct MessagePayload { - pub(crate) channel: u8, - pub(crate) detail: MessageDetail, -} - -impl MessagePayload { - pub fn from_bytes>(bytes: T) -> Self { - serde_json::from_slice(bytes.as_ref()).unwrap() - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum MessageDetail { - Document(MessageContent), - Database(MessageContent), -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct MessageContent { - content: String, -} - -#[derive(Debug)] -pub enum WSError { - Internal, -} diff --git a/src/component/ws/mod.rs b/src/component/ws/mod.rs deleted file mode 100644 index 65125f13..00000000 --- a/src/component/ws/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -use std::time::Duration; - -mod client; -mod entities; -mod server; - -pub use client::*; -pub use server::WSServer; - -pub(crate) const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(8); -pub(crate) const PING_TIMEOUT: Duration = Duration::from_secs(60); diff --git a/src/component/ws/server.rs b/src/component/ws/server.rs deleted file mode 100644 index 10a384fc..00000000 --- a/src/component/ws/server.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::component::ws::entities::{Connect, Disconnect, WSError, WebSocketMessage}; -use actix::{Actor, Context, Handler}; - -#[derive(Default)] -pub struct WSServer {} - -impl WSServer { - pub fn new() -> Self { - WSServer::default() - } - - pub fn send(&self, _msg: WebSocketMessage) {} -} - -impl Actor for WSServer { - type Context = Context; - fn started(&mut self, _ctx: &mut Self::Context) {} -} - -impl Handler for WSServer { - type Result = Result<(), WSError>; - fn handle(&mut self, _msg: Connect, _ctx: &mut Context) -> Self::Result { - Ok(()) - } -} - -impl Handler for WSServer { - type Result = Result<(), WSError>; - fn handle(&mut self, _msg: Disconnect, _: &mut Context) -> Self::Result { - Ok(()) - } -} - -impl Handler for WSServer { - type Result = (); - - fn handle(&mut self, _msg: WebSocketMessage, _ctx: &mut Context) -> Self::Result {} -} - -impl actix::Supervised for WSServer { - fn restarting(&mut self, _ctx: &mut Context) { - tracing::warn!("restarting"); - } -} diff --git a/src/config/config.rs b/src/config/config.rs index 818b2156..d0da1972 100644 --- a/src/config/config.rs +++ b/src/config/config.rs @@ -3,12 +3,13 @@ use secrecy::Secret; use serde_aux::field_attributes::deserialize_number_from_string; use sqlx::postgres::{PgConnectOptions, PgSslMode}; use std::convert::{TryFrom, TryInto}; +use std::path::PathBuf; #[derive(serde::Deserialize, Clone, Debug)] pub struct Config { - pub database: DatabaseSetting, - pub application: ApplicationSettings, - pub redis_uri: Secret, + pub database: DatabaseSetting, + pub application: ApplicationSettings, + pub redis_uri: Secret, } // We are using 127.0.0.1 as our host in address, we are instructing our @@ -21,73 +22,78 @@ pub struct Config { // #[derive(serde::Deserialize, Clone, Debug)] pub struct ApplicationSettings { - #[serde(deserialize_with = "deserialize_number_from_string")] - pub port: u16, - pub host: String, - pub server_key: Secret, - pub tls_config: Option, + #[serde(deserialize_with = "deserialize_number_from_string")] + pub port: u16, + pub host: String, + pub data_dir: PathBuf, + pub server_key: Secret, + pub tls_config: Option, } impl ApplicationSettings { - pub fn use_https(&self) -> bool { - match &self.tls_config { - None => false, - Some(config) => match config { - TlsConfig::NoTls => false, - TlsConfig::SelfSigned => true, - }, - } + pub fn use_https(&self) -> bool { + match &self.tls_config { + None => false, + Some(config) => match config { + TlsConfig::NoTls => false, + TlsConfig::SelfSigned => true, + }, } + } + + pub fn rocksdb_db_dir(&self) -> PathBuf { + self.data_dir.join("rocksdb") + } } #[derive(serde::Deserialize, Clone, Debug)] #[serde(rename_all = "snake_case")] pub enum TlsConfig { - NoTls, - SelfSigned, + NoTls, + SelfSigned, } #[derive(serde::Deserialize, Clone, Debug)] pub struct DatabaseSetting { - pub username: String, - pub password: String, - #[serde(deserialize_with = "deserialize_number_from_string")] - pub port: u16, - pub host: String, - pub database_name: String, - pub require_ssl: bool, + pub username: String, + pub password: String, + #[serde(deserialize_with = "deserialize_number_from_string")] + pub port: u16, + pub host: String, + pub database_name: String, + pub require_ssl: bool, } impl DatabaseSetting { - pub fn without_db(&self) -> PgConnectOptions { - let ssl_mode = if self.require_ssl { - PgSslMode::Require - } else { - PgSslMode::Prefer - }; - PgConnectOptions::new() - .host(&self.host) - .username(&self.username) - .password(&self.password) - .port(self.port) - .ssl_mode(ssl_mode) - } + pub fn without_db(&self) -> PgConnectOptions { + let ssl_mode = if self.require_ssl { + PgSslMode::Require + } else { + PgSslMode::Prefer + }; + PgConnectOptions::new() + .host(&self.host) + .username(&self.username) + .password(&self.password) + .port(self.port) + .ssl_mode(ssl_mode) + } - pub fn with_db(&self) -> PgConnectOptions { - self.without_db().database(&self.database_name) - } + pub fn with_db(&self) -> PgConnectOptions { + self.without_db().database(&self.database_name) + } } pub fn get_configuration() -> Result { - let base_path = std::env::current_dir().expect("Failed to determine the current directory"); - let configuration_dir = base_path.join("configuration"); + let base_path = std::env::current_dir().expect("Failed to determine the current directory"); + let configuration_dir = base_path.join("configuration"); - let environment: Environment = std::env::var("APP_ENVIRONMENT") - .unwrap_or_else(|_| "local".into()) - .try_into() - .expect("Failed to parse APP_ENVIRONMENT."); + let environment: Environment = std::env::var("APP_ENVIRONMENT") + .unwrap_or_else(|_| "local".into()) + .try_into() + .expect("Failed to parse APP_ENVIRONMENT."); - let builder = InnerConfig::builder() + let builder = InnerConfig::builder() .set_default("default", "1")? .add_source( config::File::from(configuration_dir.join("base")) @@ -104,36 +110,36 @@ pub fn get_configuration() -> Result { // `Settings.application.port` .add_source(config::Environment::with_prefix("app").separator("__")); - let config = builder.build()?; - config.try_deserialize() + let config = builder.build()?; + config.try_deserialize() } /// The possible runtime environment for our application. pub enum Environment { - Local, - Production, + Local, + Production, } impl Environment { - pub fn as_str(&self) -> &'static str { - match self { - Environment::Local => "local", - Environment::Production => "production", - } + pub fn as_str(&self) -> &'static str { + match self { + Environment::Local => "local", + Environment::Production => "production", } + } } impl TryFrom for Environment { - type Error = String; + type Error = String; - fn try_from(s: String) -> Result { - match s.to_lowercase().as_str() { - "local" => Ok(Self::Local), - "production" => Ok(Self::Production), - other => Err(format!( - "{} is not a supported environment. Use either `local` or `production`.", - other - )), - } + fn try_from(s: String) -> Result { + match s.to_lowercase().as_str() { + "local" => Ok(Self::Local), + "production" => Ok(Self::Production), + other => Err(format!( + "{} is not a supported environment. Use either `local` or `production`.", + other + )), } + } } diff --git a/src/config/env.rs b/src/config/env.rs index c37fa6d2..f9ae65c1 100644 --- a/src/config/env.rs +++ b/src/config/env.rs @@ -1,13 +1,13 @@ use std::env; pub fn domain() -> String { - env::var("DOMAIN").unwrap_or_else(|_| "localhost".to_string()) + env::var("DOMAIN").unwrap_or_else(|_| "localhost".to_string()) } pub fn jwt_secret() -> String { - env::var("JWT_SECRET").unwrap_or_else(|_| "my secret".into()) + env::var("JWT_SECRET").unwrap_or_else(|_| "my secret".into()) } pub fn secret() -> String { - env::var("SECRET_KEY").unwrap_or_else(|_| "0123".repeat(8)) + env::var("SECRET_KEY").unwrap_or_else(|_| "0123".repeat(8)) } diff --git a/src/domain/user_email.rs b/src/domain/user_email.rs index 1b5dbfb8..1f45c9db 100644 --- a/src/domain/user_email.rs +++ b/src/domain/user_email.rs @@ -4,44 +4,44 @@ use validator::validate_email; pub struct UserEmail(pub String); impl UserEmail { - pub fn parse(s: String) -> Result { - if s.trim().is_empty() { - return Err("Email can not be empty or whitespace".to_string()); - } - - if validate_email(&s) { - Ok(Self(s)) - } else { - Err("Invalid email".to_string()) - } + pub fn parse(s: String) -> Result { + if s.trim().is_empty() { + return Err("Email can not be empty or whitespace".to_string()); } + + if validate_email(&s) { + Ok(Self(s)) + } else { + Err("Invalid email".to_string()) + } + } } impl AsRef for UserEmail { - fn as_ref(&self) -> &str { - &self.0 - } + fn as_ref(&self) -> &str { + &self.0 + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn empty_string_is_rejected() { - let email = "".to_string(); - assert!(UserEmail::parse(email).is_err()); - } + #[test] + fn empty_string_is_rejected() { + let email = "".to_string(); + assert!(UserEmail::parse(email).is_err()); + } - #[test] - fn email_missing_at_symbol_is_rejected() { - let email = "helloworld.com".to_string(); - assert!(UserEmail::parse(email).is_err()); - } + #[test] + fn email_missing_at_symbol_is_rejected() { + let email = "helloworld.com".to_string(); + assert!(UserEmail::parse(email).is_err()); + } - #[test] - fn email_missing_subject_is_rejected() { - let email = "@domain.com".to_string(); - assert!(UserEmail::parse(email).is_err()); - } + #[test] + fn email_missing_subject_is_rejected() { + let email = "@domain.com".to_string(); + assert!(UserEmail::parse(email).is_err()); + } } diff --git a/src/domain/user_name.rs b/src/domain/user_name.rs index 74209f20..b1bfc104 100644 --- a/src/domain/user_name.rs +++ b/src/domain/user_name.rs @@ -4,78 +4,78 @@ use unicode_segmentation::UnicodeSegmentation; pub struct UserName(pub String); impl UserName { - pub fn parse(s: String) -> Result { - let is_empty_or_whitespace = s.trim().is_empty(); - if is_empty_or_whitespace { - return Err("User name can not be empty or whitespace".to_string()); - } - // A grapheme is defined by the Unicode standard as a "user-perceived" - // character: `å` is a single grapheme, but it is composed of two characters - // (`a` and `̊`). - // - // `graphemes` returns an iterator over the graphemes in the input `s`. - // `true` specifies that we want to use the extended grapheme definition set, - // the recommended one. - let is_too_long = s.graphemes(true).count() > 256; - if is_too_long { - return Err("User name is too long".to_string()); - } - - let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}']; - let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g)); - - if contains_forbidden_characters { - return Err("User name contains invalid characters".to_string()); - } - Ok(Self(s)) + pub fn parse(s: String) -> Result { + let is_empty_or_whitespace = s.trim().is_empty(); + if is_empty_or_whitespace { + return Err("User name can not be empty or whitespace".to_string()); } + // A grapheme is defined by the Unicode standard as a "user-perceived" + // character: `å` is a single grapheme, but it is composed of two characters + // (`a` and `̊`). + // + // `graphemes` returns an iterator over the graphemes in the input `s`. + // `true` specifies that we want to use the extended grapheme definition set, + // the recommended one. + let is_too_long = s.graphemes(true).count() > 256; + if is_too_long { + return Err("User name is too long".to_string()); + } + + let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}']; + let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g)); + + if contains_forbidden_characters { + return Err("User name contains invalid characters".to_string()); + } + Ok(Self(s)) + } } impl AsRef for UserName { - fn as_ref(&self) -> &str { - &self.0 - } + fn as_ref(&self) -> &str { + &self.0 + } } #[cfg(test)] mod tests { - use super::UserName; + use super::UserName; - #[test] - fn a_256_grapheme_long_name_is_valid() { - let name = "a̐".repeat(256); - assert!(UserName::parse(name).is_ok()); - } + #[test] + fn a_256_grapheme_long_name_is_valid() { + let name = "a̐".repeat(256); + assert!(UserName::parse(name).is_ok()); + } - #[test] - fn a_name_longer_than_256_graphemes_is_rejected() { - let name = "a".repeat(257); - assert!(UserName::parse(name).is_err()); - } + #[test] + fn a_name_longer_than_256_graphemes_is_rejected() { + let name = "a".repeat(257); + assert!(UserName::parse(name).is_err()); + } - #[test] - fn whitespace_only_names_are_rejected() { - let name = " ".to_string(); - assert!(UserName::parse(name).is_err()); - } + #[test] + fn whitespace_only_names_are_rejected() { + let name = " ".to_string(); + assert!(UserName::parse(name).is_err()); + } - #[test] - fn empty_string_is_rejected() { - let name = "".to_string(); - assert!(UserName::parse(name).is_err()); - } + #[test] + fn empty_string_is_rejected() { + let name = "".to_string(); + assert!(UserName::parse(name).is_err()); + } - #[test] - fn names_containing_an_invalid_character_are_rejected() { - for name in &['/', '(', ')', '"', '<', '>', '\\', '{', '}'] { - let name = name.to_string(); - assert!(UserName::parse(name).is_err()); - } + #[test] + fn names_containing_an_invalid_character_are_rejected() { + for name in &['/', '(', ')', '"', '<', '>', '\\', '{', '}'] { + let name = name.to_string(); + assert!(UserName::parse(name).is_err()); } + } - #[test] - fn a_valid_name_is_parsed_successfully() { - let name = "nathan".to_string(); - assert!(UserName::parse(name).is_ok()); - } + #[test] + fn a_valid_name_is_parsed_successfully() { + let name = "nathan".to_string(); + assert!(UserName::parse(name).is_ok()); + } } diff --git a/src/domain/user_password.rs b/src/domain/user_password.rs index c98be39c..de3a0dcd 100644 --- a/src/domain/user_password.rs +++ b/src/domain/user_password.rs @@ -6,33 +6,33 @@ use unicode_segmentation::UnicodeSegmentation; pub struct UserPassword(pub String); impl UserPassword { - pub fn parse(s: String) -> Result { - if s.trim().is_empty() { - return Err("User password can not be empty or whitespace".to_owned()); - } - - if s.graphemes(true).count() > 100 { - return Err("Password is too long".to_owned()); - } - - let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}']; - let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g)); - if contains_forbidden_characters { - return Err("Password contains invalid characters".to_string()); - } - - if !validate_password(&s) { - return Err("Password format invalid".to_string()); - } - - Ok(Self(s)) + pub fn parse(s: String) -> Result { + if s.trim().is_empty() { + return Err("User password can not be empty or whitespace".to_owned()); } + + if s.graphemes(true).count() > 100 { + return Err("Password is too long".to_owned()); + } + + let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}']; + let contains_forbidden_characters = s.chars().any(|g| forbidden_characters.contains(&g)); + if contains_forbidden_characters { + return Err("Password contains invalid characters".to_string()); + } + + if !validate_password(&s) { + return Err("Password format invalid".to_string()); + } + + Ok(Self(s)) + } } impl AsRef for UserPassword { - fn as_ref(&self) -> &str { - &self.0 - } + fn as_ref(&self) -> &str { + &self.0 + } } lazy_static! { @@ -53,11 +53,11 @@ lazy_static! { } pub fn validate_password(password: &str) -> bool { - match PASSWORD.is_match(password) { - Ok(is_match) => is_match, - Err(e) => { - tracing::error!("validate_password fail: {:?}", e); - false - } - } + match PASSWORD.is_match(password) { + Ok(is_match) => is_match, + Err(e) => { + tracing::error!("validate_password fail: {:?}", e); + false + }, + } } diff --git a/src/main.rs b/src/main.rs index 587d05e2..4704a815 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,13 +4,17 @@ use appflowy_server::telemetry::{get_subscriber, init_subscriber}; #[actix_web::main] async fn main() -> anyhow::Result<()> { - let subscriber = get_subscriber("appflowy_server".into(), "info".into(), std::io::stdout); - init_subscriber(subscriber); + let subscriber = get_subscriber( + "appflowy_server".to_string(), + "info".to_string(), + std::io::stdout, + ); + init_subscriber(subscriber); - let configuration = get_configuration().expect("Failed to read configuration."); - let state = init_state(&configuration).await; - let application = Application::build(configuration, state).await?; - application.run_until_stopped().await?; + let configuration = get_configuration().expect("Failed to read configuration."); + let state = init_state(&configuration).await; + let application = Application::build(configuration, state).await?; + application.run_until_stopped().await?; - Ok(()) + Ok(()) } diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index 516c16c9..00f7654b 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -6,11 +6,11 @@ use actix_web::http; // http://www.ruanyifeng.com/blog/2016/04/cors.html // Cors short for Cross-Origin Resource Sharing. pub fn default_cors() -> Cors { - Cors::default() // allowed_origin return access-control-allow-origin: * by default - // .allowed_origin("http://127.0.0.1:8080") - .send_wildcard() - .allowed_methods(vec!["GET", "POST", "PUT", "DELETE"]) - .allowed_headers(vec![http::header::ACCEPT]) - .allowed_header(http::header::CONTENT_TYPE) - .max_age(3600) + Cors::default() // allowed_origin return access-control-allow-origin: * by default + .allow_any_origin() + .send_wildcard() + .allowed_methods(vec!["GET", "POST", "PUT", "DELETE"]) + .allowed_headers(vec![http::header::ACCEPT]) + .allowed_header(http::header::CONTENT_TYPE) + .max_age(3600) } diff --git a/src/self_signed.rs b/src/self_signed.rs index dcc6706e..84725f18 100644 --- a/src/self_signed.rs +++ b/src/self_signed.rs @@ -5,26 +5,26 @@ pub const CA_CRT: &str = include_str!("../cert/cert.pem"); pub const CA_KEY: &str = include_str!("../cert/key.pem"); pub fn create_self_signed_certificate() -> Result<(Secret, Secret), RcgenError> { - let key = KeyPair::from_pem(CA_KEY)?; - let params = CertificateParams::from_ca_cert_pem(CA_CRT, key)?; - let ca_cert = Certificate::from_params(params)?; + let key = KeyPair::from_pem(CA_KEY)?; + let params = CertificateParams::from_ca_cert_pem(CA_CRT, key)?; + let ca_cert = Certificate::from_params(params)?; - let mut params = CertificateParams::default(); - params - .subject_alt_names - .push(SanType::IpAddress("127.0.0.1".parse().unwrap())); - params - .subject_alt_names - .push(SanType::IpAddress("0.0.0.0".parse().unwrap())); - params - .subject_alt_names - .push(SanType::DnsName("localhost".to_string())); + let mut params = CertificateParams::default(); + params + .subject_alt_names + .push(SanType::IpAddress("127.0.0.1".parse().unwrap())); + params + .subject_alt_names + .push(SanType::IpAddress("0.0.0.0".parse().unwrap())); + params + .subject_alt_names + .push(SanType::DnsName("localhost".to_string())); - // Generate a certificate that's valid for: - // 1. localhost - // 2. 127.0.0.1 - let gen_cert = Certificate::from_params(params)?; - let server_crt = Secret::new(gen_cert.serialize_pem_with_signer(&ca_cert)?); - let server_key = Secret::new(gen_cert.serialize_private_key_pem()); - Ok((server_crt, server_key)) + // Generate a certificate that's valid for: + // 1. localhost + // 2. 127.0.0.1 + let gen_cert = Certificate::from_params(params)?; + let server_crt = Secret::new(gen_cert.serialize_pem_with_signer(&ca_cert)?); + let server_key = Secret::new(gen_cert.serialize_private_key_pem()); + Ok((server_crt, server_key)) } diff --git a/src/state.rs b/src/state.rs index f2e620c9..16de5560 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,8 @@ use crate::component::auth::LoggedUser; use crate::config::config::Config; use chrono::{DateTime, Utc}; +use collab_persistence::kv::rocks_kv::RocksCollabDB; + use snowflake::Snowflake; use sqlx::PgPool; use std::collections::BTreeMap; @@ -9,70 +11,72 @@ use tokio::sync::RwLock; #[derive(Clone)] pub struct State { - pub pg_pool: PgPool, - pub config: Arc, - pub user: Arc>, - pub id_gen: Arc>, + pub pg_pool: PgPool, + pub rocksdb: Arc, + pub config: Arc, + pub user: Arc>, + pub id_gen: Arc>, } impl State { - pub async fn load_users(_pool: &PgPool) { - todo!() - } + pub async fn load_users(_pool: &PgPool) { + todo!() + } - pub async fn next_user_id(&self) -> i64 { - self.id_gen.write().await.next_id() - } + pub async fn next_user_id(&self) -> i64 { + self.id_gen.write().await.next_id() + } } #[derive(Clone, Debug, Copy)] enum AuthStatus { - Authorized(DateTime), - NotAuthorized, + Authorized(DateTime), + NotAuthorized, } pub const EXPIRED_DURATION_DAYS: i64 = 30; #[derive(Debug, Default)] pub struct UserCache { - // Keep track the user authentication state - user: BTreeMap, + // Keep track the user authentication state + user: BTreeMap, } impl UserCache { - pub fn new() -> Self { - UserCache::default() - } + pub fn new() -> Self { + UserCache::default() + } - pub fn is_authorized(&self, user: &LoggedUser) -> bool { - match self.user.get(user.expose_secret()) { - None => { - tracing::debug!("user not login yet or server was reboot"); - false - } - Some(status) => match *status { - AuthStatus::Authorized(last_time) => { - let current_time = Utc::now(); - let days = (current_time - last_time).num_days(); - days < EXPIRED_DURATION_DAYS - } - AuthStatus::NotAuthorized => { - tracing::debug!("user logout already"); - false - } - }, - } + pub fn is_authorized(&self, user: &LoggedUser) -> bool { + match self.user.get(user.expose_secret()) { + None => { + tracing::debug!("user not login yet or server was reboot"); + false + }, + Some(status) => match *status { + AuthStatus::Authorized(last_time) => { + let current_time = Utc::now(); + let days = (current_time - last_time).num_days(); + days < EXPIRED_DURATION_DAYS + }, + AuthStatus::NotAuthorized => { + tracing::debug!("user logout already"); + false + }, + }, } + } - pub fn authorized(&mut self, user: LoggedUser) { - self.user.insert( - user.expose_secret().to_owned(), - AuthStatus::Authorized(Utc::now()), - ); - } + pub fn authorized(&mut self, user: LoggedUser) { + self.user.insert( + user.expose_secret().to_owned(), + AuthStatus::Authorized(Utc::now()), + ); + } - pub fn unauthorized(&mut self, user: LoggedUser) { - self.user - .insert(user.expose_secret().to_owned(), AuthStatus::NotAuthorized); - } + pub fn unauthorized(&mut self, user: LoggedUser) { + self + .user + .insert(user.expose_secret().to_owned(), AuthStatus::NotAuthorized); + } } diff --git a/src/telemetry.rs b/src/telemetry.rs index 076288ce..b0b5772b 100644 --- a/src/telemetry.rs +++ b/src/telemetry.rs @@ -4,39 +4,41 @@ use tracing::Subscriber; use tracing_bunyan_formatter::{BunyanFormattingLayer, JsonStorageLayer}; use tracing_log::LogTracer; use tracing_subscriber::fmt::MakeWriter; -use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry}; +use tracing_subscriber::{layer::SubscriberExt, EnvFilter}; /// Compose multiple layers into a `tracing`'s subscriber. pub fn get_subscriber( - name: String, - env_filter: String, - sink: Sink, + name: String, + env_filter: String, + sink: Sink, ) -> impl Subscriber + Sync + Send where - Sink: for<'a> MakeWriter<'a> + Send + Sync + 'static, + Sink: for<'a> MakeWriter<'a> + Send + Sync + 'static, { - let env_filter = - EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(env_filter)); - let formatting_layer = BunyanFormattingLayer::new(name, sink); - Registry::default() - .with(env_filter) - .with(JsonStorageLayer) - .with(formatting_layer) + let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(env_filter)); + // let env_filter = EnvFilter::new(env_filter); + let formatting_layer = BunyanFormattingLayer::new(name, sink); + tracing_subscriber::fmt() + .with_ansi(true) + .finish() + .with(env_filter) + .with(JsonStorageLayer) + .with(formatting_layer) } /// Register a subscriber as global default to process span data. /// /// It should only be called once! pub fn init_subscriber(subscriber: impl Subscriber + Sync + Send) { - LogTracer::init().expect("Failed to set logger"); - set_global_default(subscriber).expect("Failed to set subscriber"); + LogTracer::init().expect("Failed to set logger"); + set_global_default(subscriber).expect("Failed to set subscriber"); } pub fn spawn_blocking_with_tracing(f: F) -> JoinHandle where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, { - let current_span = tracing::Span::current(); - actix_web::rt::task::spawn_blocking(move || current_span.in_scope(f)) + let current_span = tracing::Span::current(); + actix_web::rt::task::spawn_blocking(move || current_span.in_scope(f)) } diff --git a/tests/api/login.rs b/tests/api/login.rs index 8e6f9c9f..a354946e 100644 --- a/tests/api/login.rs +++ b/tests/api/login.rs @@ -1,44 +1,44 @@ -use crate::test_server::{spawn_server, TestUser}; +use crate::util::{spawn_server, TestUser}; use actix_web::http::StatusCode; use appflowy_server::component::auth::LoginResponse; -#[tokio::test] +#[actix_rt::test] async fn login_success() { - let server = spawn_server().await; - let test_user = TestUser::generate(); - test_user.register(&server).await; + let server = spawn_server().await; + let test_user = TestUser::generate(); + test_user.register(&server).await; - let http_resp = server.login(&test_user.email, &test_user.password).await; - assert_eq!(http_resp.status(), StatusCode::OK); + let http_resp = server.login(&test_user.email, &test_user.password).await; + assert_eq!(http_resp.status(), StatusCode::OK); - let bytes = http_resp.bytes().await.unwrap(); - let response: LoginResponse = serde_json::from_slice(&bytes).unwrap(); - assert!(!response.token.is_empty()) + let bytes = http_resp.bytes().await.unwrap(); + let response: LoginResponse = serde_json::from_slice(&bytes).unwrap(); + assert!(!response.token.is_empty()) } -#[tokio::test] +#[actix_rt::test] async fn login_with_empty_email() { - let server = spawn_server().await; - let test_user = TestUser::generate(); - test_user.register(&server).await; + let server = spawn_server().await; + let test_user = TestUser::generate(); + test_user.register(&server).await; - let http_resp = server.login("", &test_user.password).await; - assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); + let http_resp = server.login("", &test_user.password).await; + assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); } -#[tokio::test] +#[actix_rt::test] async fn login_with_empty_password() { - let server = spawn_server().await; - let test_user = TestUser::generate(); - test_user.register(&server).await; + let server = spawn_server().await; + let test_user = TestUser::generate(); + test_user.register(&server).await; - let http_resp = server.login(&test_user.email, "").await; - assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); + let http_resp = server.login(&test_user.email, "").await; + assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); } -#[tokio::test] +#[actix_rt::test] async fn login_with_unknown_user() { - let server = spawn_server().await; - let http_resp = server.login("unknown@appflowy.io", "Abc@123!").await; - assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED); + let server = spawn_server().await; + let http_resp = server.login("unknown@appflowy.io", "Abc@123!").await; + assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED); } diff --git a/tests/api/main.rs b/tests/api/mod.rs similarity index 69% rename from tests/api/main.rs rename to tests/api/mod.rs index d80e94ba..05dfb27f 100644 --- a/tests/api/main.rs +++ b/tests/api/mod.rs @@ -1,4 +1,4 @@ mod login; mod password; mod register; -mod test_server; +mod ws; diff --git a/tests/api/password.rs b/tests/api/password.rs index 957c9e4e..a895fd64 100644 --- a/tests/api/password.rs +++ b/tests/api/password.rs @@ -1,53 +1,53 @@ -use crate::test_server::{spawn_server, TestUser}; +use crate::util::{spawn_server, TestUser}; use actix_web::http::StatusCode; -#[tokio::test] +#[actix_rt::test] async fn change_password_with_unmatched_password() { - let server = spawn_server().await; - let test_user = TestUser::generate(); - let token = test_user.register(&server).await; + let server = spawn_server().await; + let test_user = TestUser::generate(); + let token = test_user.register(&server).await; - let new_password = "HelloWorld@1a"; - let new_password_confirm = "HeloWorld@1a"; - let http_resp = server - .change_password( - token, - &test_user.password, - new_password, - new_password_confirm, - ) - .await; - assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); + let new_password = "HelloWorld@1a"; + let new_password_confirm = "HeloWorld@1a"; + let http_resp = server + .change_password( + token, + &test_user.password, + new_password, + new_password_confirm, + ) + .await; + assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); } -#[tokio::test] +#[actix_rt::test] async fn login_fail_after_change_password() { - let server = spawn_server().await; - let test_user = TestUser::generate(); - let token = test_user.register(&server).await; + let server = spawn_server().await; + let test_user = TestUser::generate(); + let token = test_user.register(&server).await; - let new_password = "HelloWorld@1a"; - let http_resp = server - .change_password(token, &test_user.password, new_password, new_password) - .await; - assert_eq!(http_resp.status(), StatusCode::OK); + let new_password = "HelloWorld@1a"; + let http_resp = server + .change_password(token, &test_user.password, new_password, new_password) + .await; + assert_eq!(http_resp.status(), StatusCode::OK); - let http_resp = server.login(&test_user.email, &test_user.password).await; - assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED); + let http_resp = server.login(&test_user.email, &test_user.password).await; + assert_eq!(http_resp.status(), StatusCode::UNAUTHORIZED); } -#[tokio::test] +#[actix_rt::test] async fn login_success_with_new_password() { - let server = spawn_server().await; - let test_user = TestUser::generate(); - let token = test_user.register(&server).await; + let server = spawn_server().await; + let test_user = TestUser::generate(); + let token = test_user.register(&server).await; - let new_password = "HelloWorld@1a"; - let http_resp = server - .change_password(token, &test_user.password, new_password, new_password) - .await; - assert_eq!(http_resp.status(), StatusCode::OK); + let new_password = "HelloWorld@1a"; + let http_resp = server + .change_password(token, &test_user.password, new_password, new_password) + .await; + assert_eq!(http_resp.status(), StatusCode::OK); - let http_resp = server.login(&test_user.email, new_password).await; - assert_eq!(http_resp.status(), StatusCode::OK); + let http_resp = server.login(&test_user.email, new_password).await; + assert_eq!(http_resp.status(), StatusCode::OK); } diff --git a/tests/api/register.rs b/tests/api/register.rs index 1fd1864a..d81b76ba 100644 --- a/tests/api/register.rs +++ b/tests/api/register.rs @@ -1,53 +1,53 @@ -use crate::test_server::{error_msg_from_resp, spawn_server}; +use crate::util::{error_msg_from_resp, spawn_server}; use appflowy_server::component::auth::{InputParamsError, RegisterResponse}; use reqwest::StatusCode; -#[tokio::test] +#[actix_rt::test] // curl -X POST --url http://0.0.0.0:8000/api/user/register --header 'content-type: application/json' --data '{"name":"fake name", "email":"fake@appflowy.io", "password":"Fake@123"}' async fn register_success() { - let server = spawn_server().await; - let http_resp = server - .register("user 1", "fake@appflowy.io", "FakePassword!123") - .await; + let server = spawn_server().await; + let http_resp = server + .register("user 1", "fake@appflowy.io", "FakePassword!123") + .await; - let bytes = http_resp.bytes().await.unwrap(); - let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap(); - println!("{:?}", response); + let bytes = http_resp.bytes().await.unwrap(); + let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap(); + println!("{:?}", response); } -#[tokio::test] +#[actix_rt::test] async fn register_with_invalid_password() { - let server = spawn_server().await; - let http_resp = server.register("user 1", "fake@appflowy.io", "123").await; - assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); - assert_eq!( - error_msg_from_resp(http_resp).await, - InputParamsError::InvalidPassword.to_string() - ); + let server = spawn_server().await; + let http_resp = server.register("user 1", "fake@appflowy.io", "123").await; + assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); + assert_eq!( + error_msg_from_resp(http_resp).await, + InputParamsError::InvalidPassword.to_string() + ); } -#[tokio::test] +#[actix_rt::test] async fn register_with_invalid_name() { - let server = spawn_server().await; - let name = "".to_string(); - let http_resp = server - .register(&name, "fake@appflowy.io", "FakePassword!123") - .await; - assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); - assert_eq!( - error_msg_from_resp(http_resp).await, - InputParamsError::InvalidName(name).to_string() - ); + let server = spawn_server().await; + let name = "".to_string(); + let http_resp = server + .register(&name, "fake@appflowy.io", "FakePassword!123") + .await; + assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); + assert_eq!( + error_msg_from_resp(http_resp).await, + InputParamsError::InvalidName(name).to_string() + ); } -#[tokio::test] +#[actix_rt::test] async fn register_with_invalid_email() { - let server = spawn_server().await; - let email = "appflowy.io".to_string(); - let http_resp = server.register("me", &email, "FakePassword!123").await; - assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); - assert_eq!( - error_msg_from_resp(http_resp).await, - InputParamsError::InvalidEmail(email).to_string() - ); + let server = spawn_server().await; + let email = "appflowy.io".to_string(); + let http_resp = server.register("me", &email, "FakePassword!123").await; + assert_eq!(http_resp.status(), StatusCode::BAD_REQUEST); + assert_eq!( + error_msg_from_resp(http_resp).await, + InputParamsError::InvalidEmail(email).to_string() + ); } diff --git a/tests/api/test_server.rs b/tests/api/test_server.rs deleted file mode 100644 index e45c6b6e..00000000 --- a/tests/api/test_server.rs +++ /dev/null @@ -1,189 +0,0 @@ -use appflowy_server::application::{init_state, Application}; -use appflowy_server::config::config::{get_configuration, DatabaseSetting, TlsConfig}; -use appflowy_server::state::State; -use appflowy_server::telemetry::{get_subscriber, init_subscriber}; -use once_cell::sync::Lazy; -use reqwest::Certificate; - -use appflowy_server::component::auth::{RegisterResponse, HEADER_TOKEN}; -use sqlx::types::Uuid; -use sqlx::{Connection, Executor, PgConnection, PgPool}; - -// Ensure that the `tracing` stack is only initialised once using `once_cell` -static TRACING: Lazy<()> = Lazy::new(|| { - let level = "info".to_string(); - let mut filters = vec![]; - filters.push(format!("appflowy_server={}", level)); - filters.push(format!("hyper={}", level)); - - let subscriber_name = "test".to_string(); - let subscriber = get_subscriber(subscriber_name, filters.join(","), std::io::stdout); - init_subscriber(subscriber); -}); - -#[derive(Clone)] -pub struct TestServer { - pub state: State, - pub api_client: reqwest::Client, - pub address: String, - pub port: u16, -} - -impl TestServer { - pub async fn register(&self, name: &str, email: &str, password: &str) -> reqwest::Response { - let payload = serde_json::json!({ - "name": name, - "password": password, - "email": email - }); - let url = format!("{}/api/user/register", self.address); - self.api_client - .post(&url) - .json(&payload) - .send() - .await - .expect("Register failed") - } - - pub async fn login(&self, email: &str, password: &str) -> reqwest::Response { - let payload = serde_json::json!({ - "password": password, - "email": email - }); - let url = format!("{}/api/user/login", self.address); - self.api_client - .post(&url) - .json(&payload) - .send() - .await - .expect("Login failed") - } - - pub async fn change_password( - &self, - token: String, - current_password: &str, - new_password: &str, - new_password_confirm: &str, - ) -> reqwest::Response { - let payload = serde_json::json!({ - "current_password": current_password, - "new_password": new_password, - "new_password_confirm": new_password_confirm - }); - let url = format!("{}/api/user/password", self.address); - self.api_client - .post(&url) - .header(HEADER_TOKEN, token) - .json(&payload) - .send() - .await - .expect("Change password failed") - } -} - -pub async fn spawn_server() -> TestServer { - Lazy::force(&TRACING); - let database_name = Uuid::new_v4().to_string(); - let config = { - let mut config = get_configuration().expect("Failed to read configuration."); - config.database.database_name = database_name.clone(); - // Use a random OS port - config.application.port = 0; - config - }; - - let _ = configure_database(&config.database).await; - let state = init_state(&config).await; - let application = Application::build(config.clone(), state.clone()) - .await - .expect("Failed to build application"); - - let port = application.port(); - let _ = tokio::spawn(async { - let _ = application.run_until_stopped().await; - }); - let mut builder = reqwest::Client::builder(); - let mut address = format!("http://localhost:{}", port); - if config.application.use_https() { - address = format!("https://localhost:{}", port); - builder = builder.add_root_certificate( - Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap(), - ); - } - let api_client = builder - .add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap()) - .redirect(reqwest::redirect::Policy::none()) - .danger_accept_invalid_certs(true) - .cookie_store(true) - .no_proxy() - .build() - .unwrap(); - - TestServer { - state, - api_client, - address, - port, - } -} - -async fn configure_database(config: &DatabaseSetting) -> PgPool { - // Create database - let mut connection = PgConnection::connect_with(&config.without_db()) - .await - .expect("Failed to connect to Postgres"); - connection - .execute(&*format!(r#"CREATE DATABASE "{}";"#, config.database_name)) - .await - .expect("Failed to create database."); - - // Migrate database - let connection_pool = PgPool::connect_with(config.with_db()) - .await - .expect("Failed to connect to Postgres."); - - sqlx::migrate!("./migrations") - .run(&connection_pool) - .await - .expect("Failed to migrate the database"); - - connection_pool -} - -#[derive(serde::Serialize)] -pub struct TestUser { - name: String, - pub email: String, - pub password: String, -} - -impl TestUser { - pub fn generate() -> Self { - Self { - name: "Me".to_string(), - email: "me@appflowy.io".to_string(), - password: "Hello@AppFlowy123".to_string(), - } - } - - pub async fn register(&self, test_server: &TestServer) -> String { - let url = format!("{}/api/user/register", test_server.address); - let resp = test_server - .api_client - .post(&url) - .json(self) - .send() - .await - .expect("Fail to register user"); - - let bytes = resp.bytes().await.unwrap(); - let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap(); - response.token - } -} - -pub async fn error_msg_from_resp(resp: reqwest::Response) -> String { - let bytes = resp.bytes().await.unwrap(); - String::from_utf8(bytes.to_vec()).unwrap() -} diff --git a/tests/api/ws.rs b/tests/api/ws.rs new file mode 100644 index 00000000..aa823df7 --- /dev/null +++ b/tests/api/ws.rs @@ -0,0 +1,13 @@ +use crate::util::{spawn_server, TestUser}; +use collab_client_ws::WSClient; + +#[actix_rt::test] +async fn ws_conn_test() { + let server = spawn_server().await; + let test_user = TestUser::generate(); + let token = test_user.register(&server).await; + + let address = format!("{}/{}", server.ws_addr, token); + let client = WSClient::new(address, 100); + let _ = client.connect().await; +} diff --git a/tests/main.rs b/tests/main.rs new file mode 100644 index 00000000..0860deec --- /dev/null +++ b/tests/main.rs @@ -0,0 +1,3 @@ +mod api; +mod util; +mod ws; diff --git a/tests/util/mod.rs b/tests/util/mod.rs new file mode 100644 index 00000000..1efb8ecd --- /dev/null +++ b/tests/util/mod.rs @@ -0,0 +1,2 @@ +mod test_server; +pub use test_server::*; diff --git a/tests/util/test_server.rs b/tests/util/test_server.rs new file mode 100644 index 00000000..78bb6ad6 --- /dev/null +++ b/tests/util/test_server.rs @@ -0,0 +1,232 @@ +use appflowy_server::application::{init_state, Application}; +use appflowy_server::config::config::{get_configuration, DatabaseSetting}; +use appflowy_server::state::State; +use appflowy_server::telemetry::{get_subscriber, init_subscriber}; +use once_cell::sync::Lazy; +use reqwest::Certificate; +use std::path::PathBuf; + +use appflowy_server::component::auth::{RegisterResponse, HEADER_TOKEN}; +use sqlx::types::Uuid; +use sqlx::{Connection, Executor, PgConnection, PgPool}; + +// Ensure that the `tracing` stack is only initialised once using `once_cell` +static TRACING: Lazy<()> = Lazy::new(|| { + let level = "trace".to_string(); + let mut filters = vec![]; + filters.push(format!("appflowy_server={}", level)); + filters.push(format!("collab_client_ws={}", level)); + filters.push(format!("hyper={}", level)); + filters.push(format!("actix_web={}", level)); + + let subscriber_name = "test".to_string(); + let subscriber = get_subscriber(subscriber_name, filters.join(","), std::io::stdout); + init_subscriber(subscriber); +}); + +#[derive(Clone)] +pub struct TestServer { + pub state: State, + pub api_client: reqwest::Client, + pub address: String, + pub port: u16, + pub ws_addr: String, + #[allow(dead_code)] + pub cleaner: Cleaner, +} + +impl TestServer { + pub async fn register(&self, name: &str, email: &str, password: &str) -> reqwest::Response { + let payload = serde_json::json!({ + "name": name, + "password": password, + "email": email + }); + let url = format!("{}/api/user/register", self.address); + self + .api_client + .post(&url) + .json(&payload) + .send() + .await + .expect("Register failed") + } + + pub async fn login(&self, email: &str, password: &str) -> reqwest::Response { + let payload = serde_json::json!({ + "password": password, + "email": email + }); + let url = format!("{}/api/user/login", self.address); + self + .api_client + .post(&url) + .json(&payload) + .send() + .await + .expect("Login failed") + } + + pub async fn change_password( + &self, + token: String, + current_password: &str, + new_password: &str, + new_password_confirm: &str, + ) -> reqwest::Response { + let payload = serde_json::json!({ + "current_password": current_password, + "new_password": new_password, + "new_password_confirm": new_password_confirm + }); + let url = format!("{}/api/user/password", self.address); + self + .api_client + .post(&url) + .header(HEADER_TOKEN, token) + .json(&payload) + .send() + .await + .expect("Change password failed") + } +} + +pub async fn spawn_server() -> TestServer { + Lazy::force(&TRACING); + let database_name = Uuid::new_v4().to_string(); + let config = { + let mut config = get_configuration().expect("Failed to read configuration."); + config.database.database_name = database_name.clone(); + // Use a random OS port + config.application.port = 0; + config.application.data_dir = PathBuf::from(format!("./data/{}", database_name)); + config + }; + + let _ = configure_database(&config.database).await; + let state = init_state(&config).await; + let application = Application::build(config.clone(), state.clone()) + .await + .expect("Failed to build application"); + + let port = application.port(); + let _ = tokio::spawn(async { + let _ = application.run_until_stopped().await; + }); + let mut builder = reqwest::Client::builder(); + let mut address = format!("http://localhost:{}", port); + let mut ws_addr = format!("ws://localhost:{}/ws", port); + if config.application.use_https() { + address = format!("https://localhost:{}", port); + ws_addr = format!("wss://localhost:{}/ws", port); + builder = builder + .add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap()); + } + + let api_client = builder + .add_root_certificate(Certificate::from_pem(include_bytes!("../../cert/cert.pem")).unwrap()) + .redirect(reqwest::redirect::Policy::none()) + .danger_accept_invalid_certs(true) + .cookie_store(true) + .no_proxy() + .build() + .unwrap(); + + let cleaner = Cleaner::new(config.application.data_dir); + + TestServer { + state, + api_client, + address, + ws_addr, + port, + cleaner, + } +} + +async fn configure_database(config: &DatabaseSetting) -> PgPool { + // Create database + let mut connection = PgConnection::connect_with(&config.without_db()) + .await + .expect("Failed to connect to Postgres"); + connection + .execute(&*format!(r#"CREATE DATABASE "{}";"#, config.database_name)) + .await + .expect("Failed to create database."); + + // Migrate database + let connection_pool = PgPool::connect_with(config.with_db()) + .await + .expect("Failed to connect to Postgres."); + + sqlx::migrate!("./migrations") + .run(&connection_pool) + .await + .expect("Failed to migrate the database"); + + connection_pool +} + +#[derive(serde::Serialize)] +pub struct TestUser { + name: String, + pub email: String, + pub password: String, +} + +impl TestUser { + pub fn generate() -> Self { + Self { + name: "Me".to_string(), + email: "me@appflowy.io".to_string(), + password: "Hello@AppFlowy123".to_string(), + } + } + + pub async fn register(&self, test_server: &TestServer) -> String { + let url = format!("{}/api/user/register", test_server.address); + let resp = test_server + .api_client + .post(&url) + .json(self) + .send() + .await + .expect("Fail to register user"); + + let bytes = resp.bytes().await.unwrap(); + let response: RegisterResponse = serde_json::from_slice(&bytes).unwrap(); + response.token + } +} + +pub async fn error_msg_from_resp(resp: reqwest::Response) -> String { + let bytes = resp.bytes().await.unwrap(); + String::from_utf8(bytes.to_vec()).unwrap() +} + +#[derive(Clone)] +pub struct Cleaner { + path: PathBuf, + should_clean: bool, +} + +impl Cleaner { + fn new(path: PathBuf) -> Self { + Self { + path, + should_clean: true, + } + } + + fn cleanup(dir: &PathBuf) { + let _ = std::fs::remove_dir_all(dir); + } +} + +impl Drop for Cleaner { + fn drop(&mut self) { + if self.should_clean { + Self::cleanup(&self.path) + } + } +} diff --git a/tests/ws/mod.rs b/tests/ws/mod.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/ws/mod.rs @@ -0,0 +1 @@ +