From 2c1c820e68dce74a19821b0fe4a75fc41a647182 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Sun, 6 Apr 2025 19:11:40 +0800 Subject: [PATCH] chore: support azure open ai (#1321) --- Cargo.lock | 419 ++++++++++++++---- Cargo.toml | 14 +- libs/appflowy-ai-client/src/dto.rs | 65 +-- libs/indexer/Cargo.toml | 1 + .../src/collab_indexer/document_indexer.rs | 40 +- libs/indexer/src/collab_indexer/provider.rs | 4 +- libs/indexer/src/scheduler.rs | 62 ++- libs/indexer/src/unindexed_workspace.rs | 4 +- libs/indexer/src/vector/embedder.rs | 45 +- libs/indexer/src/vector/mod.rs | 1 - libs/indexer/src/vector/open_ai.rs | 116 ++--- libs/indexer/src/vector/rest.rs | 189 -------- .../tests/indexer_test.rs | 4 +- services/appflowy-worker/src/application.rs | 26 +- .../src/indexer_worker/worker.rs | 101 +++-- src/application.rs | 8 +- src/biz/search/ops.rs | 38 +- 17 files changed, 549 insertions(+), 588 deletions(-) delete mode 100644 libs/indexer/src/vector/rest.rs diff --git a/Cargo.lock b/Cargo.lock index 59b96ffc..4437c966 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -399,7 +399,7 @@ dependencies = [ "sha2", "shared-entity", "tokio", - "tower", + "tower 0.4.13", "tower-http", "tower-service", "tracing", @@ -648,7 +648,7 @@ dependencies = [ "reqwest", "sanitize-filename", "scraper", - "secrecy", + "secrecy 0.8.0", "semver", "serde", "serde_json", @@ -712,7 +712,7 @@ dependencies = [ "rand 0.8.5", "rayon", "redis 0.25.4", - "secrecy", + "secrecy 0.8.0", "semver", "serde", "serde_json", @@ -760,7 +760,7 @@ dependencies = [ "prometheus-client", "rayon", "redis 0.25.4", - "secrecy", + "secrecy 0.8.0", "serde", "serde_json", "sqlx", @@ -880,6 +880,43 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-openai" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c566b15aa847e60a9e6c9b9b4b9d4be94bbf776804624279afa69559fea7e1" +dependencies = [ + "async-openai-macros", + "backoff", + "base64 0.22.1", + "bytes", + "derive_builder", + "eventsource-stream", + "futures", + "rand 0.8.5", + "reqwest", + "reqwest-eventsource", + "secrecy 0.10.3", + "serde", + "serde_json", + "thiserror 2.0.12", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", +] + +[[package]] +name = "async-openai-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "async-recursion" version = "1.1.1" @@ -965,7 +1002,7 @@ dependencies = [ "argon2", "gotrue-entity", "rand 0.8.5", - "secrecy", + "secrecy 0.8.0", "serde", "sqlx", "thiserror 1.0.63", @@ -1273,7 +1310,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.1", "tokio", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -1317,12 +1354,26 @@ dependencies = [ "mime", "pin-project-lite", "serde", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", ] +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.15", + "instant", + "pin-project-lite", + "rand 0.8.5", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.73" @@ -1535,7 +1586,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" dependencies = [ "memchr", - "regex-automata 0.4.7", + "regex-automata 0.4.9", "serde", ] @@ -1888,7 +1939,7 @@ dependencies = [ [[package]] name = "collab" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998" dependencies = [ "anyhow", "arc-swap", @@ -1913,7 +1964,7 @@ dependencies = [ [[package]] name = "collab-database" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998" dependencies = [ "anyhow", "async-trait", @@ -1953,7 +2004,7 @@ dependencies = [ [[package]] name = "collab-document" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998" dependencies = [ "anyhow", "arc-swap", @@ -1974,7 +2025,7 @@ dependencies = [ [[package]] name = "collab-entity" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998" dependencies = [ "anyhow", "bytes", @@ -1994,7 +2045,7 @@ dependencies = [ [[package]] name = "collab-folder" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998" dependencies = [ "anyhow", "arc-swap", @@ -2016,7 +2067,7 @@ dependencies = [ [[package]] name = "collab-importer" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998" dependencies = [ "anyhow", "async-recursion", @@ -2124,7 +2175,7 @@ dependencies = [ [[package]] name = "collab-user" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998" dependencies = [ "anyhow", "collab", @@ -2652,6 +2703,37 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.90", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.90", +] + [[package]] name = "derive_more" version = "0.99.18" @@ -2847,6 +2929,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "fancy-regex" version = "0.11.0" @@ -2864,8 +2957,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" dependencies = [ "bit-set", - "regex-automata 0.4.7", - "regex-syntax 0.8.4", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", ] [[package]] @@ -3552,7 +3645,7 @@ dependencies = [ "hyper 0.14.30", "log", "rustls 0.21.12", - "rustls-native-certs", + "rustls-native-certs 0.6.3", "tokio", "tokio-rustls 0.24.1", ] @@ -3568,6 +3661,7 @@ dependencies = [ "hyper 1.4.1", "hyper-util", "rustls 0.23.20", + "rustls-native-certs 0.7.3", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", @@ -3823,6 +3917,7 @@ dependencies = [ "anyhow", "app-error", "appflowy-ai-client", + "async-openai", "async-trait", "chrono", "collab", @@ -3836,7 +3931,7 @@ dependencies = [ "rayon", "redis 0.25.4", "reqwest", - "secrecy", + "secrecy 0.8.0", "serde", "serde_json", "sqlx", @@ -3927,6 +4022,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -3953,10 +4057,11 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ + "once_cell", "wasm-bindgen", ] @@ -4034,9 +4139,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.155" +version = "0.2.171" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" [[package]] name = "libm" @@ -4149,7 +4254,7 @@ dependencies = [ "anyhow", "handlebars", "lettre", - "secrecy", + "secrecy 0.8.0", "serde", ] @@ -5116,7 +5221,7 @@ checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" dependencies = [ "bytes", "heck 0.5.0", - "itertools 0.12.1", + "itertools 0.13.0", "log", "multimap", "once_cell", @@ -5149,7 +5254,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5" dependencies = [ "anyhow", - "itertools 0.12.1", + "itertools 0.13.0", "proc-macro2", "quote", "syn 2.0.90", @@ -5540,14 +5645,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.7", - "regex-syntax 0.8.4", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", ] [[package]] @@ -5561,13 +5666,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", ] [[package]] @@ -5584,9 +5689,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "rend" @@ -5599,9 +5704,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.9" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" +checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" dependencies = [ "base64 0.22.1", "bytes", @@ -5629,6 +5734,7 @@ dependencies = [ "pin-project-lite", "quinn", "rustls 0.23.20", + "rustls-native-certs 0.8.0", "rustls-pemfile 2.1.2", "rustls-pki-types", "serde", @@ -5640,6 +5746,7 @@ dependencies = [ "tokio-native-tls", "tokio-rustls 0.26.0", "tokio-util", + "tower 0.5.2", "tower-service", "url", "wasm-bindgen", @@ -5650,6 +5757,22 @@ dependencies = [ "windows-registry", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest", + "thiserror 1.0.63", +] + [[package]] name = "rfc6979" version = "0.3.1" @@ -5868,6 +5991,32 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-native-certs" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.1.2", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.1.2", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -6039,6 +6188,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "secrecy" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -6092,18 +6251,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.204" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", @@ -6123,9 +6282,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "itoa", "memchr", @@ -6858,11 +7017,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.9" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.9", + "thiserror-impl 2.0.12", ] [[package]] @@ -6878,9 +7037,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.9" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", @@ -6991,9 +7150,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.2" +version = "1.44.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" dependencies = [ "backtrace", "bytes", @@ -7010,9 +7169,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", @@ -7075,9 +7234,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" dependencies = [ "futures-core", "pin-project-lite", @@ -7095,7 +7254,7 @@ dependencies = [ "log", "native-tls", "rustls 0.21.12", - "rustls-native-certs", + "rustls-native-certs 0.6.3", "tokio", "tokio-native-tls", "tokio-rustls 0.24.1", @@ -7119,9 +7278,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.12" +version = "0.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034" dependencies = [ "bytes", "futures-core", @@ -7172,7 +7331,7 @@ dependencies = [ "socket2", "tokio", "tokio-stream", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -7221,6 +7380,21 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 1.0.1", + "tokio", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-http" version = "0.5.2" @@ -7248,15 +7422,15 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -7406,7 +7580,7 @@ dependencies = [ "native-tls", "rand 0.8.5", "sha1", - "thiserror 2.0.9", + "thiserror 2.0.12", "utf-8", ] @@ -7670,23 +7844,24 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", + "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", "syn 2.0.90", @@ -7707,9 +7882,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -7717,9 +7892,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", @@ -7730,9 +7905,12 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "wasm-bindgen-test" @@ -7890,33 +8068,38 @@ dependencies = [ ] [[package]] -name = "windows-registry" -version = "0.2.0" +name = "windows-link" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" + +[[package]] +name = "windows-registry" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" dependencies = [ "windows-result", "windows-strings", - "windows-targets 0.52.6", + "windows-targets 0.53.0", ] [[package]] name = "windows-result" -version = "0.2.0" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252" dependencies = [ - "windows-targets 0.52.6", + "windows-link", ] [[package]] name = "windows-strings" -version = "0.1.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" dependencies = [ - "windows-result", - "windows-targets 0.52.6", + "windows-link", ] [[package]] @@ -7961,13 +8144,29 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm", + "windows_i686_gnullvm 0.52.6", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows-targets" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -7980,6 +8179,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -7992,6 +8197,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -8004,12 +8215,24 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -8022,6 +8245,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -8034,6 +8263,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -8046,6 +8281,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -8058,6 +8299,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "winnow" version = "0.5.40" diff --git a/Cargo.toml b/Cargo.toml index 543465ee..f37448fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -303,13 +303,13 @@ lto = false [patch.crates-io] # It's diffcult to resovle different version with the same crate used in AppFlowy Frontend and the Client-API crate. # So using patch to workaround this issue. -collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } -collab-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } -collab-folder = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } -collab-document = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } -collab-user = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } -collab-database = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } -collab-importer = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } +collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" } +collab-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" } +collab-folder = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" } +collab-document = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" } +collab-user = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" } +collab-database = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" } +collab-importer = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" } [features] history = [] diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index 87061abd..2befc654 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -220,69 +220,6 @@ pub struct TranslateRowResponse { pub items: Vec>, } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -#[serde(untagged)] -pub enum EmbeddingInput { - /// The string that will be turned into an embedding. - String(String), - /// The array of strings that will be turned into an embedding. - StringArray(Vec), -} - -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -#[serde(untagged)] -pub enum EmbeddingOutput { - Float(Vec), - Base64(String), -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct Embedding { - /// An integer representing the index of the embedding in the list of embeddings. - pub index: i32, - /// The embedding value, which is an instance of `EmbeddingOutput`. - pub embedding: EmbeddingOutput, -} - -/// https://platform.openai.com/docs/api-reference/embeddings -#[derive(Serialize, Deserialize, Debug)] -pub struct OpenAIEmbeddingResponse { - /// A string that is always set to "embedding". - pub object: String, - /// A list of `Embedding` objects. - pub data: Vec, - /// A string representing the model used to generate the embeddings. - pub model: String, - /// An integer representing the total number of tokens used. - pub usage: EmbeddingUsage, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct EmbeddingUsage { - #[serde(default)] - pub prompt_tokens: i32, - pub total_tokens: i32, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -#[serde(rename_all = "lowercase")] -pub enum EmbeddingEncodingFormat { - Float, - Base64, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct EmbeddingRequest { - /// An instance of `EmbeddingInput` containing the data to be embedded. - pub input: EmbeddingInput, - /// A string representing the model to use for generating embeddings. - pub model: String, - /// An instance of `EmbeddingEncodingFormat` representing the format of the embedding. - pub encoding_format: EmbeddingEncodingFormat, - /// An integer representing the number of dimensions for the embedding. - pub dimensions: i32, -} - #[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum EmbeddingModel { #[serde(rename = "text-embedding-3-small")] @@ -310,7 +247,7 @@ impl EmbeddingModel { } } - pub fn default_dimensions(&self) -> i32 { + pub fn default_dimensions(&self) -> u32 { match self { EmbeddingModel::TextEmbeddingAda002 => 1536, EmbeddingModel::TextEmbedding3Large => 3072, diff --git a/libs/indexer/Cargo.toml b/libs/indexer/Cargo.toml index a492a249..af94a682 100644 --- a/libs/indexer/Cargo.toml +++ b/libs/indexer/Cargo.toml @@ -35,3 +35,4 @@ redis = { workspace = true, features = [ secrecy = { workspace = true, features = ["serde"] } reqwest.workspace = true twox-hash = { version = "2.1.0", features = ["xxhash64"] } +async-openai = "0.28.0" \ No newline at end of file diff --git a/libs/indexer/src/collab_indexer/document_indexer.rs b/libs/indexer/src/collab_indexer/document_indexer.rs index 60a994e9..c6cefac6 100644 --- a/libs/indexer/src/collab_indexer/document_indexer.rs +++ b/libs/indexer/src/collab_indexer/document_indexer.rs @@ -1,11 +1,10 @@ use crate::collab_indexer::Indexer; -use crate::vector::embedder::Embedder; +use crate::vector::embedder::AFEmbedder; use crate::vector::open_ai::group_paragraphs_by_max_content_len; use anyhow::anyhow; use app_error::AppError; -use appflowy_ai_client::dto::{ - EmbeddingEncodingFormat, EmbeddingInput, EmbeddingModel, EmbeddingOutput, EmbeddingRequest, -}; +use appflowy_ai_client::dto::EmbeddingModel; +use async_openai::types::{CreateEmbeddingRequestArgs, EmbeddingInput, EncodingFormat}; use async_trait::async_trait; use collab::preclude::Collab; use collab_document::document::DocumentBody; @@ -48,7 +47,7 @@ impl Indexer for DocumentIndexer { async fn embed( &self, - embedder: &Embedder, + embedder: &AFEmbedder, mut content: Vec, ) -> Result, AppError> { if content.is_empty() { @@ -59,14 +58,16 @@ impl Indexer for DocumentIndexer { .iter() .map(|fragment| fragment.content.clone().unwrap_or_default()) .collect(); - let resp = embedder - .async_embed(EmbeddingRequest { - input: EmbeddingInput::StringArray(contents), - model: embedder.model().name().to_string(), - encoding_format: EmbeddingEncodingFormat::Float, - dimensions: EmbeddingModel::TextEmbedding3Small.default_dimensions(), - }) - .await?; + + let request = CreateEmbeddingRequestArgs::default() + .model(embedder.model().name()) + .input(EmbeddingInput::StringArray(contents)) + .encoding_format(EncodingFormat::Float) + .dimensions(EmbeddingModel::TextEmbedding3Small.default_dimensions()) + .build() + .map_err(|err| AppError::Unhandled(err.to_string()))?; + + let resp = embedder.async_embed(request).await?; trace!( "[Embedding] request {} embeddings, received {} embeddings", @@ -77,21 +78,12 @@ impl Indexer for DocumentIndexer { for embedding in resp.data { let param = &mut content[embedding.index as usize]; if param.content.is_some() { - // we only set the embedding if the content was not marked as unchanged - let embedding: Vec = match embedding.embedding { - EmbeddingOutput::Float(embedding) => embedding.into_iter().map(|f| f as f32).collect(), - EmbeddingOutput::Base64(_) => { - return Err(AppError::OpenError( - "Unexpected base64 encoding".to_string(), - )) - }, - }; - param.embedding = Some(embedding); + param.embedding = Some(embedding.embedding); } } Ok(Some(AFCollabEmbeddings { - tokens_consumed: resp.usage.total_tokens as u32, + tokens_consumed: resp.usage.total_tokens, params: content, })) } diff --git a/libs/indexer/src/collab_indexer/provider.rs b/libs/indexer/src/collab_indexer/provider.rs index 3004bec5..0262c441 100644 --- a/libs/indexer/src/collab_indexer/provider.rs +++ b/libs/indexer/src/collab_indexer/provider.rs @@ -1,5 +1,5 @@ use crate::collab_indexer::DocumentIndexer; -use crate::vector::embedder::Embedder; +use crate::vector::embedder::AFEmbedder; use app_error::AppError; use appflowy_ai_client::dto::EmbeddingModel; use async_trait::async_trait; @@ -29,7 +29,7 @@ pub trait Indexer: Send + Sync { async fn embed( &self, - embedder: &Embedder, + embedder: &AFEmbedder, content: Vec, ) -> Result, AppError>; } diff --git a/libs/indexer/src/scheduler.rs b/libs/indexer/src/scheduler.rs index 5eae189c..969a1f9f 100644 --- a/libs/indexer/src/scheduler.rs +++ b/libs/indexer/src/scheduler.rs @@ -2,10 +2,11 @@ use crate::collab_indexer::IndexerProvider; use crate::entity::EmbeddingRecord; use crate::metrics::EmbeddingMetrics; use crate::queue::add_background_embed_task; -use crate::vector::embedder::Embedder; +use crate::vector::embedder::AFEmbedder; use crate::vector::open_ai; use app_error::AppError; -use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse}; +use async_openai::config::{AzureConfig, OpenAIConfig}; +use async_openai::types::{CreateEmbeddingRequest, CreateEmbeddingResponse}; use collab::preclude::Collab; use collab_document::document::DocumentBody; use collab_entity::CollabType; @@ -16,7 +17,6 @@ use database::index::{ use database::workspace::select_workspace_settings; use infra::env_util::get_env_var; use redis::aio::ConnectionManager; -use secrecy::{ExposeSecret, Secret}; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use std::cmp::max; @@ -48,7 +48,8 @@ pub struct IndexerScheduler { #[derive(Debug)] pub struct IndexerConfiguration { pub enable: bool, - pub openai_api_key: Secret, + pub open_ai_config: Option, + pub azure_ai_config: Option, /// High watermark for the number of embeddings that can be buffered before being written to the database. pub embedding_buffer_size: usize, } @@ -114,39 +115,36 @@ impl IndexerScheduler { } fn index_enabled(&self) -> bool { - // if indexing is disabled, return false - if !self.config.enable { - return false; - } - - // if openai api key is empty, return false - if self.config.openai_api_key.expose_secret().is_empty() { - return false; - } - - true + self.config.enable + && (self.config.open_ai_config.is_some() || self.config.azure_ai_config.is_some()) } pub fn is_indexing_enabled(&self, collab_type: CollabType) -> bool { self.indexer_provider.is_indexing_enabled(collab_type) } - pub(crate) fn create_embedder(&self) -> Result { - if self.config.openai_api_key.expose_secret().is_empty() { - return Err(AppError::AIServiceUnavailable( - "OpenAI API key is empty".to_string(), - )); + pub(crate) fn create_embedder(&self) -> Result { + if let Some(config) = &self.config.azure_ai_config { + return Ok(AFEmbedder::AzureOpenAI(open_ai::AzureOpenAIEmbedder::new( + config.clone(), + ))); } - Ok(Embedder::OpenAI(open_ai::Embedder::new( - self.config.openai_api_key.expose_secret().clone(), - ))) + if let Some(config) = &self.config.open_ai_config { + return Ok(AFEmbedder::OpenAI(open_ai::OpenAIEmbedder::new( + config.clone(), + ))); + } + + Err(AppError::AIServiceUnavailable( + "No embedder available".to_string(), + )) } pub async fn create_search_embeddings( &self, - request: EmbeddingRequest, - ) -> Result { + request: CreateEmbeddingRequest, + ) -> Result { let embedder = self.create_embedder()?; let embeddings = embedder.async_embed(request).await?; Ok(embeddings) @@ -327,14 +325,12 @@ async fn generate_embeddings_loop( match embedder { Ok(embedder) => { let params: Vec<_> = records.iter().map(|r| r.object_id).collect(); - let existing_embeddings = - match get_collab_embedding_fragment_ids(&scheduler.pg_pool, params).await { - Ok(existing_embeddings) => existing_embeddings, - Err(err) => { - error!("[Embedding] failed to get existing embeddings: {}", err); - Default::default() - }, - }; + let existing_embeddings = get_collab_embedding_fragment_ids(&scheduler.pg_pool, params) + .await + .unwrap_or_else(|err| { + error!("[Embedding] failed to get existing embeddings: {}", err); + Default::default() + }); let mut join_set = JoinSet::new(); for record in records { if let Some(indexer) = indexer_provider.indexer_for(record.collab_type) { diff --git a/libs/indexer/src/unindexed_workspace.rs b/libs/indexer/src/unindexed_workspace.rs index 8692544b..d7fec750 100644 --- a/libs/indexer/src/unindexed_workspace.rs +++ b/libs/indexer/src/unindexed_workspace.rs @@ -1,7 +1,7 @@ use crate::collab_indexer::IndexerProvider; use crate::entity::{EmbeddingRecord, UnindexedCollab}; use crate::scheduler::{batch_insert_records, IndexerScheduler}; -use crate::vector::embedder::Embedder; +use crate::vector::embedder::AFEmbedder; use appflowy_ai_client::dto::EmbeddingModel; use collab::core::collab::DataSource; use collab::core::origin::CollabOrigin; @@ -157,7 +157,7 @@ async fn stream_unindexed_collabs( .boxed() } async fn create_embeddings( - embedder: Embedder, + embedder: AFEmbedder, indexer_provider: &Arc, unindexed_records: Vec, existing_embeddings: HashMap>, diff --git a/libs/indexer/src/vector/embedder.rs b/libs/indexer/src/vector/embedder.rs index 81554716..cd4d8f0c 100644 --- a/libs/indexer/src/vector/embedder.rs +++ b/libs/indexer/src/vector/embedder.rs @@ -1,24 +1,28 @@ use crate::vector::open_ai; +use crate::vector::open_ai::async_embed; use app_error::AppError; -use appflowy_ai_client::dto::{EmbeddingModel, EmbeddingRequest, OpenAIEmbeddingResponse}; +use appflowy_ai_client::dto::EmbeddingModel; +pub use async_openai::config::{AzureConfig, OpenAIConfig}; +pub use async_openai::types::{ + CreateEmbeddingRequest, CreateEmbeddingRequestArgs, CreateEmbeddingResponse, EmbeddingInput, + EncodingFormat, +}; +use infra::env_util::get_env_var_opt; #[derive(Debug, Clone)] -pub enum Embedder { - OpenAI(open_ai::Embedder), +pub enum AFEmbedder { + OpenAI(open_ai::OpenAIEmbedder), + AzureOpenAI(open_ai::AzureOpenAIEmbedder), } -impl Embedder { - pub fn embed(&self, params: EmbeddingRequest) -> Result { - match self { - Self::OpenAI(embedder) => embedder.embed(params), - } - } +impl AFEmbedder { pub async fn async_embed( &self, - params: EmbeddingRequest, - ) -> Result { + params: CreateEmbeddingRequest, + ) -> Result { match self { - Self::OpenAI(embedder) => embedder.async_embed(params).await, + Self::OpenAI(embedder) => async_embed(&embedder.client, params).await, + Self::AzureOpenAI(embedder) => async_embed(&embedder.client, params).await, } } @@ -26,3 +30,20 @@ impl Embedder { EmbeddingModel::TextEmbedding3Small } } + +pub fn open_ai_config() -> Option { + get_env_var_opt("AI_OPENAI_API_KEY").map(|v| OpenAIConfig::default().with_api_key(v)) +} + +pub fn azure_open_ai_config() -> Option { + let azure_open_ai_api_key = get_env_var_opt("AI_AZURE_OPENAI_API_KEY")?; + let azure_open_ai_api_base = get_env_var_opt("AI_AZURE_OPENAI_API_BASE")?; + let azure_open_ai_api_version = get_env_var_opt("AI_AZURE_OPENAI_API_VERSION")?; + + Some( + AzureConfig::new() + .with_api_key(azure_open_ai_api_key) + .with_api_base(azure_open_ai_api_base) + .with_api_version(azure_open_ai_api_version), + ) +} diff --git a/libs/indexer/src/vector/mod.rs b/libs/indexer/src/vector/mod.rs index 51a449fb..f8a6d1a8 100644 --- a/libs/indexer/src/vector/mod.rs +++ b/libs/indexer/src/vector/mod.rs @@ -1,3 +1,2 @@ pub mod embedder; pub mod open_ai; -mod rest; diff --git a/libs/indexer/src/vector/open_ai.rs b/libs/indexer/src/vector/open_ai.rs index a49e7d72..4aefda8e 100644 --- a/libs/indexer/src/vector/open_ai.rs +++ b/libs/indexer/src/vector/open_ai.rs @@ -1,9 +1,7 @@ -use crate::vector::rest::check_ureq_response; -use anyhow::anyhow; use app_error::AppError; -use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse}; -use serde::de::DeserializeOwned; -use std::time::Duration; +use async_openai::config::{AzureConfig, Config, OpenAIConfig}; +use async_openai::types::{CreateEmbeddingRequest, CreateEmbeddingResponse}; +use async_openai::Client; use tiktoken_rs::CoreBPE; pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; @@ -11,98 +9,42 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; pub const REQUEST_PARALLELISM: usize = 40; #[derive(Debug, Clone)] -pub struct Embedder { - bearer: String, - sync_client: ureq::Agent, - async_client: reqwest::Client, +pub struct OpenAIEmbedder { + pub(crate) client: Client, } -impl Embedder { - pub fn new(api_key: String) -> Self { - let bearer = format!("Bearer {api_key}"); - let sync_client = ureq::AgentBuilder::new() - .max_idle_connections(REQUEST_PARALLELISM * 2) - .max_idle_connections_per_host(REQUEST_PARALLELISM * 2) - .build(); +impl OpenAIEmbedder { + pub fn new(config: OpenAIConfig) -> Self { + let client = Client::with_config(config); - let async_client = reqwest::Client::builder().build().unwrap(); - - Self { - bearer, - sync_client, - async_client, - } - } - - pub fn embed(&self, params: EmbeddingRequest) -> Result { - for attempt in 0..3 { - let request = self - .sync_client - .post(OPENAI_EMBEDDINGS_URL) - .set("Authorization", &self.bearer) - .set("Content-Type", "application/json"); - - let result = check_ureq_response(request.send_json(¶ms)); - let retry_duration = match result { - Ok(response) => { - let data = from_ureq_response::(response)?; - return Ok(data); - }, - Err(retry) => retry.into_duration(attempt), - } - .map_err(|err| AppError::Internal(err.into()))?; - let retry_duration = retry_duration.min(Duration::from_secs(10)); - std::thread::sleep(retry_duration); - } - - Err(AppError::Internal(anyhow!( - "Failed to generate embeddings after 3 attempts" - ))) - } - - pub async fn async_embed( - &self, - params: EmbeddingRequest, - ) -> Result { - let request = self - .async_client - .post(OPENAI_EMBEDDINGS_URL) - .header("Authorization", &self.bearer) - .header("Content-Type", "application/json"); - - let result = request.json(¶ms).send().await?; - let response = from_response::(result).await?; - Ok(response) + Self { client } } } -pub fn from_ureq_response(resp: ureq::Response) -> Result -where - T: DeserializeOwned, -{ - let status_code = resp.status(); - if status_code != 200 { - let body = resp.into_string()?; - anyhow::bail!("error code: {}, {}", status_code, body) - } - - let resp = resp.into_json()?; - Ok(resp) +#[derive(Debug, Clone)] +pub struct AzureOpenAIEmbedder { + pub(crate) client: Client, } -pub async fn from_response(resp: reqwest::Response) -> Result -where - T: DeserializeOwned, -{ - let status_code = resp.status(); - if status_code != 200 { - let body = resp.text().await?; - anyhow::bail!("error code: {}, {}", status_code, body) +impl AzureOpenAIEmbedder { + pub fn new(config: AzureConfig) -> Self { + let client = Client::with_config(config); + Self { client } } - - let resp = resp.json().await?; - Ok(resp) } + +pub async fn async_embed( + client: &Client, + request: CreateEmbeddingRequest, +) -> Result { + let response = client + .embeddings() + .create(request) + .await + .map_err(|err| AppError::Unhandled(err.to_string()))?; + Ok(response) +} + /// ## Execution Time Comparison Results /// /// The following results were observed when running `execution_time_comparison_tests`: diff --git a/libs/indexer/src/vector/rest.rs b/libs/indexer/src/vector/rest.rs deleted file mode 100644 index 17dc20ac..00000000 --- a/libs/indexer/src/vector/rest.rs +++ /dev/null @@ -1,189 +0,0 @@ -use crate::thread_pool::CatchedPanic; - -#[derive(Debug, thiserror::Error)] -#[error("{fault}: {kind}")] -pub struct EmbedError { - pub kind: EmbedErrorKind, - pub fault: FaultSource, -} - -impl EmbedError { - pub(crate) fn rest_unauthorized(error_response: Option) -> EmbedError { - Self { - kind: EmbedErrorKind::RestUnauthorized(error_response), - fault: FaultSource::User, - } - } - - pub(crate) fn rest_too_many_requests(error_response: Option) -> EmbedError { - Self { - kind: EmbedErrorKind::RestTooManyRequests(error_response), - fault: FaultSource::Runtime, - } - } - - pub(crate) fn rest_bad_request(error_response: Option) -> EmbedError { - Self { - kind: EmbedErrorKind::RestBadRequest(error_response), - fault: FaultSource::User, - } - } - - pub(crate) fn rest_internal_server_error( - code: u16, - error_response: Option, - ) -> EmbedError { - Self { - kind: EmbedErrorKind::RestInternalServerError(code, error_response), - fault: FaultSource::Runtime, - } - } - - pub(crate) fn rest_other_status_code(code: u16, error_response: Option) -> EmbedError { - Self { - kind: EmbedErrorKind::RestOtherStatusCode(code, error_response), - fault: FaultSource::Unhandled, - } - } - - pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError { - Self { - kind: EmbedErrorKind::RestNetwork(transport), - fault: FaultSource::Runtime, - } - } -} - -#[derive(Debug, thiserror::Error)] -pub enum EmbedErrorKind { - #[error("could not authenticate against {}", option_info(.0.as_deref(), "server replied with "))] - RestUnauthorized(Option), - - #[error("sent too many requests to embedding server{}", option_info(.0.as_deref(), "server replied with "))] - RestTooManyRequests(Option), - - #[error("received bad request HTTP from embedding server{}", option_info(.0.as_deref(), "server replied with "))] - RestBadRequest(Option), - - #[error("received internal error HTTP {0} from embedding server{}", option_info(.1.as_deref(), "server replied with "))] - RestInternalServerError(u16, Option), - - #[error("received unexpected HTTP {0} from embedding server{}", option_info(.1.as_deref(), "server replied with "))] - RestOtherStatusCode(u16, Option), - - #[error("could not reach embedding server:\n - {0}")] - RestNetwork(ureq::Transport), - - #[error(transparent)] - PanicInThreadPool(#[from] CatchedPanic), -} - -fn option_info(info: Option<&str>, prefix: &str) -> String { - match info { - Some(info) => format!("\n - {prefix}`{info}`"), - None => String::new(), - } -} - -#[derive(Debug, Clone, Copy)] -pub enum FaultSource { - User, - Runtime, - Unhandled, -} - -impl std::fmt::Display for FaultSource { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let s = match self { - FaultSource::User => "user error", - FaultSource::Runtime => "runtime error", - FaultSource::Unhandled => "error", - }; - f.write_str(s) - } -} - -pub struct Retry { - pub error: EmbedError, - strategy: RetryStrategy, -} - -pub enum RetryStrategy { - GiveUp, - Retry, - RetryAfterRateLimit, -} - -impl Retry { - pub fn give_up(error: EmbedError) -> Self { - Self { - error, - strategy: RetryStrategy::GiveUp, - } - } - - pub fn retry_later(error: EmbedError) -> Self { - Self { - error, - strategy: RetryStrategy::Retry, - } - } - - pub fn rate_limited(error: EmbedError) -> Self { - Self { - error, - strategy: RetryStrategy::RetryAfterRateLimit, - } - } - - /// Converts a retry strategy into a delay duration based on the number of retry attempts. - /// - /// # Retry Strategies - /// - /// - **GiveUp**: If the retry strategy is `GiveUp`, the function immediately returns an error, - /// indicating that further retries should not be attempted. The error is provided by `self.error`. - /// - /// - **Retry**: If the retry strategy is `Retry`, the function applies exponential backoff based - /// on the attempt number. The delay increases exponentially with each retry attempt (e.g., - /// `10^attempt` milliseconds). For the first retry, the delay will be 10 milliseconds, - /// for the second retry 100 milliseconds, and so on. - /// - /// - **RetryAfterRateLimit**: If the retry strategy is `RetryAfterRateLimit`, the function adds - /// a fixed delay of 100 milliseconds to the exponential backoff calculation. For example, - /// for the first retry, the delay will be `100 + 10^1 = 110` milliseconds, and for the second - /// retry, it will be `100 + 10^2 = 200` milliseconds. - pub fn into_duration(self, attempt: u32) -> Result> { - match self.strategy { - RetryStrategy::GiveUp => Err(Box::new(self.error)), - RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))), - RetryStrategy::RetryAfterRateLimit => { - Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt))) - }, - } - } -} - -#[allow(clippy::result_large_err)] -pub(crate) fn check_ureq_response( - response: Result, -) -> Result { - match response { - Ok(response) => Ok(response), - Err(ureq::Error::Status(code, response)) => { - let error_response: Option = response.into_string().ok(); - Err(match code { - 401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)), - 429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)), - 400 => Retry::give_up(EmbedError::rest_bad_request(error_response)), - 500..=599 => { - Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response)) - }, - 402..=499 => Retry::give_up(EmbedError::rest_other_status_code(code, error_response)), - _ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)), - }) - }, - Err(ureq::Error::Transport(transport)) => { - Err(Retry::retry_later(EmbedError::rest_network(transport))) - }, - } -} diff --git a/services/appflowy-collaborate/tests/indexer_test.rs b/services/appflowy-collaborate/tests/indexer_test.rs index 195bed0b..2b131a19 100644 --- a/services/appflowy-collaborate/tests/indexer_test.rs +++ b/services/appflowy-collaborate/tests/indexer_test.rs @@ -11,7 +11,7 @@ fn document_plain_text() { let collab = Collab::new_with_origin(CollabOrigin::Server, "1", vec![], false); let document = Document::create_with_data(collab, doc).unwrap(); let text = document.paragraphs().join(""); - let expected = "Welcome to AppFlowy $ Download for macOS, Windows, and Linux link $ $ $ quick start Ask AI powered by advanced AI models: chat, search, write, and much more ✨ ❤\u{fe0f}Love AppFlowy and open source? Follow our latest product updates: Twitter : @appflowy Reddit : r/appflowy Github "; + let expected = "Welcome to AppFlowy$Download for macOS, Windows, and Linuxlink$$$quick startAsk AI powered by advanced AI models: chat, search, write, and much more ✨❤\u{fe0f}Love AppFlowy and open source? Follow our latest product updates:Twitter: @appflowyReddit: r/appflowyGithub"; assert_eq!(&text, expected); } @@ -21,6 +21,6 @@ fn document_plain_text_with_nested_blocks() { let collab = Collab::new_with_origin(CollabOrigin::Server, "1", vec![], false); let document = Document::create_with_data(collab, doc).unwrap(); let text = document.paragraphs().join(""); - let expected = "Welcome to AppFlowy! Here are the basics Here is H3 Click anywhere and just start typing. Click Enter to create a new line. Highlight any text, and use the editing menu to style your writing however you like. As soon as you type / a menu will pop up. Select different types of content blocks you can add. Type / followed by /bullet or /num to create a list. Click + New Page button at the bottom of your sidebar to add a new page. Click + next to any page title in the sidebar to quickly add a new subpage, Document , Grid , or Kanban Board . Keyboard shortcuts, markdown, and code block Keyboard shortcuts guide Markdown reference Type /code to insert a code block // This is the main function.\nfn main() {\n // Print text to the console.\n println!(\"Hello World!\");\n} This is a paragraph This is a paragraph Have a question❓ Click ? at the bottom right for help and support. This is a paragraph This is a paragraph Click ? at the bottom right for help and support. Like AppFlowy? Follow us: GitHub Twitter : @appflowy Newsletter "; + let expected = "Welcome to AppFlowy!Here are the basicsHere is H3Click anywhere and just start typing.Click Enter to create a new line.Highlight any text, and use the editing menu to styleyourwritinghowever you like.As soon as you type / a menu will pop up. Select different types of content blocks you can add.Type / followed by /bullet or /num to create a list.Click + New Page button at the bottom of your sidebar to add a new page.Click + next to any page title in the sidebar to quickly add a new subpage, Document, Grid, or Kanban Board.Keyboard shortcuts, markdown, and code blockKeyboard shortcuts guideMarkdown referenceType /code to insert a code block// This is the main function.\nfn main() {\n // Print text to the console.\n println!(\"Hello World!\");\n}This is a paragraphThis is a paragraphHave a question❓Click ? at the bottom right for help and support.This is a paragraphThis is a paragraphClick ? at the bottom right for help and support. Like AppFlowy? Follow us: GitHubTwitter: @appflowy Newsletter"; assert_eq!(&text, expected); } diff --git a/services/appflowy-worker/src/application.rs b/services/appflowy-worker/src/application.rs index f94b59fb..88bb1fad 100644 --- a/services/appflowy-worker/src/application.rs +++ b/services/appflowy-worker/src/application.rs @@ -21,9 +21,10 @@ use axum::response::IntoResponse; use axum::routing::get; use indexer::metrics::EmbeddingMetrics; use indexer::thread_pool::ThreadPoolNoAbortBuilder; +use indexer::vector::embedder::{azure_open_ai_config, open_ai_config}; use infra::env_util::get_env_var; use mailer::sender::Mailer; -use secrecy::{ExposeSecret, Secret}; +use secrecy::ExposeSecret; use std::sync::{Arc, Once}; use std::time::Duration; use tokio::net::TcpListener; @@ -131,21 +132,24 @@ pub async fn create_app(listener: TcpListener, config: Config) -> Result<(), Err .unwrap(), ); + let open_ai_config = open_ai_config(); + let azure_ai_config = azure_open_ai_config(); + + let indexer_config = BackgroundIndexerConfig { + enable: appflowy_collaborate::config::get_env_var("APPFLOWY_INDEXER_ENABLED", "true") + .parse::() + .unwrap_or(true), + open_ai_config, + azure_ai_config, + tick_interval_secs: 10, + }; + tokio::spawn(run_background_indexer( state.pg_pool.clone(), state.redis_client.clone(), state.metrics.embedder_metrics.clone(), threads.clone(), - BackgroundIndexerConfig { - enable: appflowy_collaborate::config::get_env_var("APPFLOWY_INDEXER_ENABLED", "true") - .parse::() - .unwrap_or(true), - open_api_key: Secret::new(appflowy_collaborate::config::get_env_var( - "AI_OPENAI_API_KEY", - "", - )), - tick_interval_secs: 10, - }, + indexer_config, )); let app = Router::new() diff --git a/services/appflowy-worker/src/indexer_worker/worker.rs b/services/appflowy-worker/src/indexer_worker/worker.rs index f25c6f8b..66a0a304 100644 --- a/services/appflowy-worker/src/indexer_worker/worker.rs +++ b/services/appflowy-worker/src/indexer_worker/worker.rs @@ -10,10 +10,9 @@ use indexer::queue::{ }; use indexer::scheduler::{spawn_pg_write_embeddings, UnindexedCollabTask, UnindexedData}; use indexer::thread_pool::ThreadPoolNoAbort; -use indexer::vector::embedder::Embedder; +use indexer::vector::embedder::{AFEmbedder, AzureConfig, OpenAIConfig}; use indexer::vector::open_ai; use redis::aio::ConnectionManager; -use secrecy::{ExposeSecret, Secret}; use sqlx::PgPool; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -25,7 +24,8 @@ use tracing::{error, info, trace, warn}; pub struct BackgroundIndexerConfig { pub enable: bool, - pub open_api_key: Secret, + pub open_ai_config: Option, + pub azure_ai_config: Option, pub tick_interval_secs: u64, } @@ -41,7 +41,7 @@ pub async fn run_background_indexer( return; } - if config.open_api_key.expose_secret().is_empty() { + if config.open_ai_config.is_none() && config.azure_ai_config.is_none() { error!("OpenAI API key is not set. Stop background indexer"); return; } @@ -160,50 +160,51 @@ async fn process_upcoming_tasks( let mut join_set = JoinSet::new(); for task in tasks { if let Some(indexer) = indexer_provider.indexer_for(task.collab_type) { - let embedder = create_embedder(&config); - trace!( - "[Background Embedding] processing task: {}, content:{:?}, collab_type: {}", - task.object_id, - task.data, - task.collab_type - ); - let paragraphs = match task.data { - UnindexedData::Paragraphs(paragraphs) => paragraphs, - UnindexedData::Text(text) => text.split('\n').map(|s| s.to_string()).collect(), - }; - let mut chunks = match indexer.create_embedded_chunks_from_text( - task.object_id, - paragraphs, - embedder.model(), - ) { - Ok(chunks) => chunks, - Err(err) => { - warn!( + if let Ok(embedder) = create_embedder(&config) { + trace!( + "[Background Embedding] processing task: {}, content:{:?}, collab_type: {}", + task.object_id, + task.data, + task.collab_type + ); + let paragraphs = match task.data { + UnindexedData::Paragraphs(paragraphs) => paragraphs, + UnindexedData::Text(text) => text.split('\n').map(|s| s.to_string()).collect(), + }; + let mut chunks = match indexer.create_embedded_chunks_from_text( + task.object_id, + paragraphs, + embedder.model(), + ) { + Ok(chunks) => chunks, + Err(err) => { + warn!( "[Background Embedding] failed to create embedded chunks for task: {}, error: {:?}", task.object_id, err ); - continue; - }, - }; - if let Some(existing_chunks) = existing_embeddings.get(&task.object_id) { - for chunk in chunks.iter_mut() { - if existing_chunks.contains(&chunk.fragment_id) { - chunk.content = None; // Clear content to mark unchanged chunk - chunk.embedding = None; + continue; + }, + }; + if let Some(existing_chunks) = existing_embeddings.get(&task.object_id) { + for chunk in chunks.iter_mut() { + if existing_chunks.contains(&chunk.fragment_id) { + chunk.content = None; // Clear content to mark unchanged chunk + chunk.embedding = None; + } } } + join_set.spawn(async move { + let embeddings = indexer.embed(&embedder, chunks).await.ok()?; + embeddings.map(|embeddings| EmbeddingRecord { + workspace_id: task.workspace_id, + object_id: task.object_id, + collab_type: task.collab_type, + tokens_used: embeddings.tokens_consumed, + contents: embeddings.params, + }) + }); } - join_set.spawn(async move { - let embeddings = indexer.embed(&embedder, chunks).await.ok()?; - embeddings.map(|embeddings| EmbeddingRecord { - workspace_id: task.workspace_id, - object_id: task.object_id, - collab_type: task.collab_type, - tokens_used: embeddings.tokens_consumed, - contents: embeddings.params, - }) - }); } } @@ -254,8 +255,20 @@ async fn process_upcoming_tasks( } } -fn create_embedder(config: &BackgroundIndexerConfig) -> Embedder { - Embedder::OpenAI(open_ai::Embedder::new( - config.open_api_key.expose_secret().clone(), +fn create_embedder(config: &BackgroundIndexerConfig) -> Result { + if let Some(config) = &config.azure_ai_config { + return Ok(AFEmbedder::AzureOpenAI(open_ai::AzureOpenAIEmbedder::new( + config.clone(), + ))); + } + + if let Some(config) = &config.open_ai_config { + return Ok(AFEmbedder::OpenAI(open_ai::OpenAIEmbedder::new( + config.clone(), + ))); + } + + Err(AppError::AIServiceUnavailable( + "No embedder available".to_string(), )) } diff --git a/src/application.rs b/src/application.rs index 76b34536..6100e427 100644 --- a/src/application.rs +++ b/src/application.rs @@ -27,7 +27,7 @@ use aws_sdk_s3::types::{ BucketInfo, BucketLocationConstraint, BucketType, CreateBucketConfiguration, }; use mailer::config::MailerSetting; -use secrecy::{ExposeSecret, Secret}; +use secrecy::ExposeSecret; use sqlx::{postgres::PgPoolOptions, PgPool}; use tokio::sync::RwLock; use tracing::{error, info}; @@ -45,6 +45,7 @@ use collab_stream::stream_router::{StreamRouter, StreamRouterOptions}; use database::file::s3_client_impl::{AwsS3BucketClientImpl, S3BucketStorage}; use indexer::collab_indexer::IndexerProvider; use indexer::scheduler::{IndexerConfiguration, IndexerScheduler}; +use indexer::vector::embedder::{azure_open_ai_config, open_ai_config}; use infra::env_util::get_env_var; use mailer::sender::Mailer; use snowflake::Snowflake; @@ -298,11 +299,14 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result() .unwrap_or(true), - openai_api_key: Secret::new(get_env_var("AI_OPENAI_API_KEY", "")), + open_ai_config, + azure_ai_config, embedding_buffer_size: appflowy_collaborate::config::get_env_var( "APPFLOWY_INDEXER_EMBEDDING_BUFFER_SIZE", "5000", diff --git a/src/biz/search/ops.rs b/src/biz/search/ops.rs index 49633002..e52e0bd2 100644 --- a/src/biz/search/ops.rs +++ b/src/biz/search/ops.rs @@ -4,9 +4,7 @@ use crate::{ api::metrics::RequestMetrics, biz::collab::folder_view::private_space_and_trash_view_ids, }; use app_error::AppError; -use appflowy_ai_client::dto::{ - EmbeddingEncodingFormat, EmbeddingInput, EmbeddingModel, EmbeddingOutput, EmbeddingRequest, -}; +use appflowy_ai_client::dto::EmbeddingModel; use appflowy_collaborate::collab::storage::CollabAccessControlStorage; use collab_folder::{Folder, View}; use database::collab::GetCollabOrigin; @@ -20,6 +18,7 @@ use shared_entity::dto::search_dto::{ use sqlx::PgPool; use indexer::scheduler::IndexerScheduler; +use indexer::vector::embedder::{CreateEmbeddingRequestArgs, EmbeddingInput, EncodingFormat}; use uuid::Uuid; static MAX_SEARCH_DEPTH: i32 = 10; @@ -80,15 +79,18 @@ pub async fn search_document( request: SearchDocumentRequest, metrics: &RequestMetrics, ) -> Result, AppError> { - let embeddings = indexer_scheduler - .create_search_embeddings(EmbeddingRequest { - input: EmbeddingInput::String(request.query.clone()), - model: EmbeddingModel::TextEmbedding3Small.to_string(), - encoding_format: EmbeddingEncodingFormat::Float, - dimensions: EmbeddingModel::TextEmbedding3Small.default_dimensions(), - }) + let embeddings_request = CreateEmbeddingRequestArgs::default() + .model(EmbeddingModel::TextEmbedding3Small.to_string()) + .input(EmbeddingInput::String(request.query.clone())) + .encoding_format(EncodingFormat::Float) + .dimensions(EmbeddingModel::TextEmbedding3Small.default_dimensions()) + .build() + .map_err(|err| AppError::Unhandled(err.to_string()))?; + + let mut embeddings_resp = indexer_scheduler + .create_search_embeddings(embeddings_request) .await?; - let total_tokens = embeddings.usage.total_tokens as u32; + let total_tokens = embeddings_resp.usage.total_tokens; metrics.record_search_tokens_used(&workspace_uuid, total_tokens); tracing::info!( "workspace {} OpenAI API search tokens used: {}", @@ -96,18 +98,10 @@ pub async fn search_document( total_tokens ); - let embedding = embeddings + let embedding = embeddings_resp .data - .first() + .pop() .ok_or_else(|| AppError::Internal(anyhow::anyhow!("OpenAI returned no embeddings")))?; - let embedding = match &embedding.embedding { - EmbeddingOutput::Float(vector) => vector.iter().map(|&v| v as f32).collect(), - EmbeddingOutput::Base64(_) => { - return Err(AppError::Internal(anyhow::anyhow!( - "OpenAI returned embeddings in unsupported format" - ))) - }, - }; let folder = get_latest_collab_folder( collab_storage, @@ -133,7 +127,7 @@ pub async fn search_document( workspace_id: workspace_uuid, limit: request.limit.unwrap_or(10) as i32, preview: request.preview_size.unwrap_or(500) as i32, - embedding, + embedding: embedding.embedding, searchable_view_ids: searchable_view_ids.into_iter().collect(), }, total_tokens,