chore: support azure open ai (#1321)

This commit is contained in:
Nathan.fooo 2025-04-06 19:11:40 +08:00 committed by GitHub
parent 1fd900d994
commit 2c1c820e68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 549 additions and 588 deletions

419
Cargo.lock generated
View file

@ -399,7 +399,7 @@ dependencies = [
"sha2",
"shared-entity",
"tokio",
"tower",
"tower 0.4.13",
"tower-http",
"tower-service",
"tracing",
@ -648,7 +648,7 @@ dependencies = [
"reqwest",
"sanitize-filename",
"scraper",
"secrecy",
"secrecy 0.8.0",
"semver",
"serde",
"serde_json",
@ -712,7 +712,7 @@ dependencies = [
"rand 0.8.5",
"rayon",
"redis 0.25.4",
"secrecy",
"secrecy 0.8.0",
"semver",
"serde",
"serde_json",
@ -760,7 +760,7 @@ dependencies = [
"prometheus-client",
"rayon",
"redis 0.25.4",
"secrecy",
"secrecy 0.8.0",
"serde",
"serde_json",
"sqlx",
@ -880,6 +880,43 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "async-openai"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36c566b15aa847e60a9e6c9b9b4b9d4be94bbf776804624279afa69559fea7e1"
dependencies = [
"async-openai-macros",
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand 0.8.5",
"reqwest",
"reqwest-eventsource",
"secrecy 0.10.3",
"serde",
"serde_json",
"thiserror 2.0.12",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-openai-macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.90",
]
[[package]]
name = "async-recursion"
version = "1.1.1"
@ -965,7 +1002,7 @@ dependencies = [
"argon2",
"gotrue-entity",
"rand 0.8.5",
"secrecy",
"secrecy 0.8.0",
"serde",
"sqlx",
"thiserror 1.0.63",
@ -1273,7 +1310,7 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper 1.0.1",
"tokio",
"tower",
"tower 0.4.13",
"tower-layer",
"tower-service",
"tracing",
@ -1317,12 +1354,26 @@ dependencies = [
"mime",
"pin-project-lite",
"serde",
"tower",
"tower 0.4.13",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand 0.8.5",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.73"
@ -1535,7 +1586,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c"
dependencies = [
"memchr",
"regex-automata 0.4.7",
"regex-automata 0.4.9",
"serde",
]
@ -1888,7 +1939,7 @@ dependencies = [
[[package]]
name = "collab"
version = "0.2.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998"
dependencies = [
"anyhow",
"arc-swap",
@ -1913,7 +1964,7 @@ dependencies = [
[[package]]
name = "collab-database"
version = "0.2.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998"
dependencies = [
"anyhow",
"async-trait",
@ -1953,7 +2004,7 @@ dependencies = [
[[package]]
name = "collab-document"
version = "0.2.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998"
dependencies = [
"anyhow",
"arc-swap",
@ -1974,7 +2025,7 @@ dependencies = [
[[package]]
name = "collab-entity"
version = "0.2.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998"
dependencies = [
"anyhow",
"bytes",
@ -1994,7 +2045,7 @@ dependencies = [
[[package]]
name = "collab-folder"
version = "0.2.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998"
dependencies = [
"anyhow",
"arc-swap",
@ -2016,7 +2067,7 @@ dependencies = [
[[package]]
name = "collab-importer"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998"
dependencies = [
"anyhow",
"async-recursion",
@ -2124,7 +2175,7 @@ dependencies = [
[[package]]
name = "collab-user"
version = "0.2.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=4a0e2cc07f50f17d1b6605c579e622a431e94998#4a0e2cc07f50f17d1b6605c579e622a431e94998"
dependencies = [
"anyhow",
"collab",
@ -2652,6 +2703,37 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "derive_builder"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
dependencies = [
"derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.90",
]
[[package]]
name = "derive_builder_macro"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
"syn 2.0.90",
]
[[package]]
name = "derive_more"
version = "0.99.18"
@ -2847,6 +2929,17 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "fancy-regex"
version = "0.11.0"
@ -2864,8 +2957,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2"
dependencies = [
"bit-set",
"regex-automata 0.4.7",
"regex-syntax 0.8.4",
"regex-automata 0.4.9",
"regex-syntax 0.8.5",
]
[[package]]
@ -3552,7 +3645,7 @@ dependencies = [
"hyper 0.14.30",
"log",
"rustls 0.21.12",
"rustls-native-certs",
"rustls-native-certs 0.6.3",
"tokio",
"tokio-rustls 0.24.1",
]
@ -3568,6 +3661,7 @@ dependencies = [
"hyper 1.4.1",
"hyper-util",
"rustls 0.23.20",
"rustls-native-certs 0.7.3",
"rustls-pki-types",
"tokio",
"tokio-rustls 0.26.0",
@ -3823,6 +3917,7 @@ dependencies = [
"anyhow",
"app-error",
"appflowy-ai-client",
"async-openai",
"async-trait",
"chrono",
"collab",
@ -3836,7 +3931,7 @@ dependencies = [
"rayon",
"redis 0.25.4",
"reqwest",
"secrecy",
"secrecy 0.8.0",
"serde",
"serde_json",
"sqlx",
@ -3927,6 +4022,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.11"
@ -3953,10 +4057,11 @@ dependencies = [
[[package]]
name = "js-sys"
version = "0.3.69"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d"
checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f"
dependencies = [
"once_cell",
"wasm-bindgen",
]
@ -4034,9 +4139,9 @@ dependencies = [
[[package]]
name = "libc"
version = "0.2.155"
version = "0.2.171"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6"
[[package]]
name = "libm"
@ -4149,7 +4254,7 @@ dependencies = [
"anyhow",
"handlebars",
"lettre",
"secrecy",
"secrecy 0.8.0",
"serde",
]
@ -5116,7 +5221,7 @@ checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
dependencies = [
"bytes",
"heck 0.5.0",
"itertools 0.12.1",
"itertools 0.13.0",
"log",
"multimap",
"once_cell",
@ -5149,7 +5254,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5"
dependencies = [
"anyhow",
"itertools 0.12.1",
"itertools 0.13.0",
"proc-macro2",
"quote",
"syn 2.0.90",
@ -5540,14 +5645,14 @@ dependencies = [
[[package]]
name = "regex"
version = "1.10.5"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata 0.4.7",
"regex-syntax 0.8.4",
"regex-automata 0.4.9",
"regex-syntax 0.8.5",
]
[[package]]
@ -5561,13 +5666,13 @@ dependencies = [
[[package]]
name = "regex-automata"
version = "0.4.7"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax 0.8.4",
"regex-syntax 0.8.5",
]
[[package]]
@ -5584,9 +5689,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.8.4"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "rend"
@ -5599,9 +5704,9 @@ dependencies = [
[[package]]
name = "reqwest"
version = "0.12.9"
version = "0.12.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f"
checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb"
dependencies = [
"base64 0.22.1",
"bytes",
@ -5629,6 +5734,7 @@ dependencies = [
"pin-project-lite",
"quinn",
"rustls 0.23.20",
"rustls-native-certs 0.8.0",
"rustls-pemfile 2.1.2",
"rustls-pki-types",
"serde",
@ -5640,6 +5746,7 @@ dependencies = [
"tokio-native-tls",
"tokio-rustls 0.26.0",
"tokio-util",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
@ -5650,6 +5757,22 @@ dependencies = [
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.63",
]
[[package]]
name = "rfc6979"
version = "0.3.1"
@ -5868,6 +5991,32 @@ dependencies = [
"security-framework",
]
[[package]]
name = "rustls-native-certs"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5"
dependencies = [
"openssl-probe",
"rustls-pemfile 2.1.2",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-native-certs"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a"
dependencies = [
"openssl-probe",
"rustls-pemfile 2.1.2",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "1.0.4"
@ -6039,6 +6188,16 @@ dependencies = [
"zeroize",
]
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
@ -6092,18 +6251,18 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.204"
version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12"
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.204"
version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222"
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
dependencies = [
"proc-macro2",
"quote",
@ -6123,9 +6282,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.128"
version = "1.0.140"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8"
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
dependencies = [
"itoa",
"memchr",
@ -6858,11 +7017,11 @@ dependencies = [
[[package]]
name = "thiserror"
version = "2.0.9"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc"
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
dependencies = [
"thiserror-impl 2.0.9",
"thiserror-impl 2.0.12",
]
[[package]]
@ -6878,9 +7037,9 @@ dependencies = [
[[package]]
name = "thiserror-impl"
version = "2.0.9"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4"
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
dependencies = [
"proc-macro2",
"quote",
@ -6991,9 +7150,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.39.2"
version = "1.44.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1"
checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48"
dependencies = [
"backtrace",
"bytes",
@ -7010,9 +7169,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.4.0"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
dependencies = [
"proc-macro2",
"quote",
@ -7075,9 +7234,9 @@ dependencies = [
[[package]]
name = "tokio-stream"
version = "0.1.16"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1"
checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047"
dependencies = [
"futures-core",
"pin-project-lite",
@ -7095,7 +7254,7 @@ dependencies = [
"log",
"native-tls",
"rustls 0.21.12",
"rustls-native-certs",
"rustls-native-certs 0.6.3",
"tokio",
"tokio-native-tls",
"tokio-rustls 0.24.1",
@ -7119,9 +7278,9 @@ dependencies = [
[[package]]
name = "tokio-util"
version = "0.7.12"
version = "0.7.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a"
checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034"
dependencies = [
"bytes",
"futures-core",
@ -7172,7 +7331,7 @@ dependencies = [
"socket2",
"tokio",
"tokio-stream",
"tower",
"tower 0.4.13",
"tower-layer",
"tower-service",
"tracing",
@ -7221,6 +7380,21 @@ dependencies = [
"tracing",
]
[[package]]
name = "tower"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9"
dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
"sync_wrapper 1.0.1",
"tokio",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-http"
version = "0.5.2"
@ -7248,15 +7422,15 @@ dependencies = [
[[package]]
name = "tower-layer"
version = "0.3.2"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
[[package]]
name = "tower-service"
version = "0.3.2"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tracing"
@ -7406,7 +7580,7 @@ dependencies = [
"native-tls",
"rand 0.8.5",
"sha1",
"thiserror 2.0.9",
"thiserror 2.0.12",
"utf-8",
]
@ -7670,23 +7844,24 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b"
[[package]]
name = "wasm-bindgen"
version = "0.2.92"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8"
checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5"
dependencies = [
"cfg-if",
"once_cell",
"rustversion",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.92"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da"
checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6"
dependencies = [
"bumpalo",
"log",
"once_cell",
"proc-macro2",
"quote",
"syn 2.0.90",
@ -7707,9 +7882,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.92"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726"
checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
@ -7717,9 +7892,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.92"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
dependencies = [
"proc-macro2",
"quote",
@ -7730,9 +7905,12 @@ dependencies = [
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.92"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96"
checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d"
dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-bindgen-test"
@ -7890,33 +8068,38 @@ dependencies = [
]
[[package]]
name = "windows-registry"
version = "0.2.0"
name = "windows-link"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38"
[[package]]
name = "windows-registry"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
"windows-targets 0.53.0",
]
[[package]]
name = "windows-result"
version = "0.2.0"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252"
dependencies = [
"windows-targets 0.52.6",
"windows-link",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
"windows-link",
]
[[package]]
@ -7961,13 +8144,29 @@ dependencies = [
"windows_aarch64_gnullvm 0.52.6",
"windows_aarch64_msvc 0.52.6",
"windows_i686_gnu 0.52.6",
"windows_i686_gnullvm",
"windows_i686_gnullvm 0.52.6",
"windows_i686_msvc 0.52.6",
"windows_x86_64_gnu 0.52.6",
"windows_x86_64_gnullvm 0.52.6",
"windows_x86_64_msvc 0.52.6",
]
[[package]]
name = "windows-targets"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b"
dependencies = [
"windows_aarch64_gnullvm 0.53.0",
"windows_aarch64_msvc 0.53.0",
"windows_i686_gnu 0.53.0",
"windows_i686_gnullvm 0.53.0",
"windows_i686_msvc 0.53.0",
"windows_x86_64_gnu 0.53.0",
"windows_x86_64_gnullvm 0.53.0",
"windows_x86_64_msvc 0.53.0",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
@ -7980,6 +8179,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
@ -7992,6 +8197,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_aarch64_msvc"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
@ -8004,12 +8215,24 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnu"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_gnullvm"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
@ -8022,6 +8245,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_i686_msvc"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
@ -8034,6 +8263,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnu"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
@ -8046,6 +8281,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
@ -8058,6 +8299,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "windows_x86_64_msvc"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
[[package]]
name = "winnow"
version = "0.5.40"

View file

@ -303,13 +303,13 @@ lto = false
[patch.crates-io]
# It's difficult to resolve different versions of the same crate used in both the AppFlowy Frontend and the Client-API crate.
# So we use [patch] to work around this issue.
collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" }
collab-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" }
collab-folder = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" }
collab-document = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" }
collab-user = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" }
collab-database = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" }
collab-importer = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" }
collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" }
collab-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" }
collab-folder = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" }
collab-document = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" }
collab-user = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" }
collab-database = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" }
collab-importer = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a0e2cc07f50f17d1b6605c579e622a431e94998" }
[features]
history = []

View file

@ -220,69 +220,6 @@ pub struct TranslateRowResponse {
pub items: Vec<HashMap<String, String>>,
}
/// Input payload of an embedding request: either one string or a list of strings.
///
/// Because of `#[serde(untagged)]`, a value is (de)serialized as the bare inner
/// value with no variant tag; on deserialization the variants are tried in
/// declaration order.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum EmbeddingInput {
/// The string that will be turned into an embedding.
String(String),
/// The array of strings that will be turned into an embedding.
StringArray(Vec<String>),
}
/// A single embedding value as carried in a response.
///
/// Because of `#[serde(untagged)]`, the value is (de)serialized without a
/// variant tag: a sequence of numbers maps to `Float`, a string to `Base64`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum EmbeddingOutput {
/// The embedding as a vector of floating-point components.
Float(Vec<f64>),
/// The embedding as a string (presumably base64-encoded, per the variant name).
Base64(String),
}
/// One entry of an embedding response: an embedding value plus its index in
/// the list of embeddings.
#[derive(Serialize, Deserialize, Debug)]
pub struct Embedding {
/// An integer representing the index of the embedding in the list of embeddings.
pub index: i32,
/// The embedding value, which is an instance of `EmbeddingOutput`.
pub embedding: EmbeddingOutput,
}
/// Response body of an embeddings request.
///
/// https://platform.openai.com/docs/api-reference/embeddings
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenAIEmbeddingResponse {
/// A string that is always set to "embedding".
pub object: String,
/// A list of `Embedding` objects.
pub data: Vec<Embedding>,
/// A string representing the model used to generate the embeddings.
pub model: String,
/// Token usage for the request, as an `EmbeddingUsage` (prompt and total token counts).
pub usage: EmbeddingUsage,
}
/// Token counts reported alongside an embedding response.
#[derive(Serialize, Deserialize, Debug)]
pub struct EmbeddingUsage {
/// Number of prompt tokens. `#[serde(default)]` makes this fall back to `0`
/// when the field is absent from the deserialized input.
#[serde(default)]
pub prompt_tokens: i32,
/// Total number of tokens. Required: deserialization fails if it is missing.
pub total_tokens: i32,
}
/// Encoding format requested for returned embeddings.
///
/// `#[serde(rename_all = "lowercase")]` makes the variants (de)serialize as
/// the strings `"float"` and `"base64"`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingEncodingFormat {
/// Serialized as `"float"`.
Float,
/// Serialized as `"base64"`.
Base64,
}
/// Request body for generating embeddings: the input text(s), the model name,
/// the desired encoding format, and the embedding dimensionality.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EmbeddingRequest {
/// An instance of `EmbeddingInput` containing the data to be embedded.
pub input: EmbeddingInput,
/// A string representing the model to use for generating embeddings.
pub model: String,
/// An instance of `EmbeddingEncodingFormat` representing the format of the embedding.
pub encoding_format: EmbeddingEncodingFormat,
/// An integer representing the number of dimensions for the embedding.
pub dimensions: i32,
}
#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum EmbeddingModel {
#[serde(rename = "text-embedding-3-small")]
@ -310,7 +247,7 @@ impl EmbeddingModel {
}
}
pub fn default_dimensions(&self) -> i32 {
pub fn default_dimensions(&self) -> u32 {
match self {
EmbeddingModel::TextEmbeddingAda002 => 1536,
EmbeddingModel::TextEmbedding3Large => 3072,

View file

@ -35,3 +35,4 @@ redis = { workspace = true, features = [
secrecy = { workspace = true, features = ["serde"] }
reqwest.workspace = true
twox-hash = { version = "2.1.0", features = ["xxhash64"] }
async-openai = "0.28.0"

View file

@ -1,11 +1,10 @@
use crate::collab_indexer::Indexer;
use crate::vector::embedder::Embedder;
use crate::vector::embedder::AFEmbedder;
use crate::vector::open_ai::group_paragraphs_by_max_content_len;
use anyhow::anyhow;
use app_error::AppError;
use appflowy_ai_client::dto::{
EmbeddingEncodingFormat, EmbeddingInput, EmbeddingModel, EmbeddingOutput, EmbeddingRequest,
};
use appflowy_ai_client::dto::EmbeddingModel;
use async_openai::types::{CreateEmbeddingRequestArgs, EmbeddingInput, EncodingFormat};
use async_trait::async_trait;
use collab::preclude::Collab;
use collab_document::document::DocumentBody;
@ -48,7 +47,7 @@ impl Indexer for DocumentIndexer {
async fn embed(
&self,
embedder: &Embedder,
embedder: &AFEmbedder,
mut content: Vec<AFCollabEmbeddedChunk>,
) -> Result<Option<AFCollabEmbeddings>, AppError> {
if content.is_empty() {
@ -59,14 +58,16 @@ impl Indexer for DocumentIndexer {
.iter()
.map(|fragment| fragment.content.clone().unwrap_or_default())
.collect();
let resp = embedder
.async_embed(EmbeddingRequest {
input: EmbeddingInput::StringArray(contents),
model: embedder.model().name().to_string(),
encoding_format: EmbeddingEncodingFormat::Float,
dimensions: EmbeddingModel::TextEmbedding3Small.default_dimensions(),
})
.await?;
let request = CreateEmbeddingRequestArgs::default()
.model(embedder.model().name())
.input(EmbeddingInput::StringArray(contents))
.encoding_format(EncodingFormat::Float)
.dimensions(EmbeddingModel::TextEmbedding3Small.default_dimensions())
.build()
.map_err(|err| AppError::Unhandled(err.to_string()))?;
let resp = embedder.async_embed(request).await?;
trace!(
"[Embedding] request {} embeddings, received {} embeddings",
@ -77,21 +78,12 @@ impl Indexer for DocumentIndexer {
for embedding in resp.data {
let param = &mut content[embedding.index as usize];
if param.content.is_some() {
// we only set the embedding if the content was not marked as unchanged
let embedding: Vec<f32> = match embedding.embedding {
EmbeddingOutput::Float(embedding) => embedding.into_iter().map(|f| f as f32).collect(),
EmbeddingOutput::Base64(_) => {
return Err(AppError::OpenError(
"Unexpected base64 encoding".to_string(),
))
},
};
param.embedding = Some(embedding);
param.embedding = Some(embedding.embedding);
}
}
Ok(Some(AFCollabEmbeddings {
tokens_consumed: resp.usage.total_tokens as u32,
tokens_consumed: resp.usage.total_tokens,
params: content,
}))
}

View file

@ -1,5 +1,5 @@
use crate::collab_indexer::DocumentIndexer;
use crate::vector::embedder::Embedder;
use crate::vector::embedder::AFEmbedder;
use app_error::AppError;
use appflowy_ai_client::dto::EmbeddingModel;
use async_trait::async_trait;
@ -29,7 +29,7 @@ pub trait Indexer: Send + Sync {
async fn embed(
&self,
embedder: &Embedder,
embedder: &AFEmbedder,
content: Vec<AFCollabEmbeddedChunk>,
) -> Result<Option<AFCollabEmbeddings>, AppError>;
}

View file

@ -2,10 +2,11 @@ use crate::collab_indexer::IndexerProvider;
use crate::entity::EmbeddingRecord;
use crate::metrics::EmbeddingMetrics;
use crate::queue::add_background_embed_task;
use crate::vector::embedder::Embedder;
use crate::vector::embedder::AFEmbedder;
use crate::vector::open_ai;
use app_error::AppError;
use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse};
use async_openai::config::{AzureConfig, OpenAIConfig};
use async_openai::types::{CreateEmbeddingRequest, CreateEmbeddingResponse};
use collab::preclude::Collab;
use collab_document::document::DocumentBody;
use collab_entity::CollabType;
@ -16,7 +17,6 @@ use database::index::{
use database::workspace::select_workspace_settings;
use infra::env_util::get_env_var;
use redis::aio::ConnectionManager;
use secrecy::{ExposeSecret, Secret};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use std::cmp::max;
@ -48,7 +48,8 @@ pub struct IndexerScheduler {
#[derive(Debug)]
pub struct IndexerConfiguration {
pub enable: bool,
pub openai_api_key: Secret<String>,
pub open_ai_config: Option<OpenAIConfig>,
pub azure_ai_config: Option<AzureConfig>,
/// High watermark for the number of embeddings that can be buffered before being written to the database.
pub embedding_buffer_size: usize,
}
@ -114,39 +115,36 @@ impl IndexerScheduler {
}
fn index_enabled(&self) -> bool {
// if indexing is disabled, return false
if !self.config.enable {
return false;
}
// if openai api key is empty, return false
if self.config.openai_api_key.expose_secret().is_empty() {
return false;
}
true
self.config.enable
&& (self.config.open_ai_config.is_some() || self.config.azure_ai_config.is_some())
}
pub fn is_indexing_enabled(&self, collab_type: CollabType) -> bool {
self.indexer_provider.is_indexing_enabled(collab_type)
}
pub(crate) fn create_embedder(&self) -> Result<Embedder, AppError> {
if self.config.openai_api_key.expose_secret().is_empty() {
return Err(AppError::AIServiceUnavailable(
"OpenAI API key is empty".to_string(),
));
pub(crate) fn create_embedder(&self) -> Result<AFEmbedder, AppError> {
if let Some(config) = &self.config.azure_ai_config {
return Ok(AFEmbedder::AzureOpenAI(open_ai::AzureOpenAIEmbedder::new(
config.clone(),
)));
}
Ok(Embedder::OpenAI(open_ai::Embedder::new(
self.config.openai_api_key.expose_secret().clone(),
)))
if let Some(config) = &self.config.open_ai_config {
return Ok(AFEmbedder::OpenAI(open_ai::OpenAIEmbedder::new(
config.clone(),
)));
}
Err(AppError::AIServiceUnavailable(
"No embedder available".to_string(),
))
}
pub async fn create_search_embeddings(
&self,
request: EmbeddingRequest,
) -> Result<OpenAIEmbeddingResponse, AppError> {
request: CreateEmbeddingRequest,
) -> Result<CreateEmbeddingResponse, AppError> {
let embedder = self.create_embedder()?;
let embeddings = embedder.async_embed(request).await?;
Ok(embeddings)
@ -327,14 +325,12 @@ async fn generate_embeddings_loop(
match embedder {
Ok(embedder) => {
let params: Vec<_> = records.iter().map(|r| r.object_id).collect();
let existing_embeddings =
match get_collab_embedding_fragment_ids(&scheduler.pg_pool, params).await {
Ok(existing_embeddings) => existing_embeddings,
Err(err) => {
let existing_embeddings = get_collab_embedding_fragment_ids(&scheduler.pg_pool, params)
.await
.unwrap_or_else(|err| {
error!("[Embedding] failed to get existing embeddings: {}", err);
Default::default()
},
};
});
let mut join_set = JoinSet::new();
for record in records {
if let Some(indexer) = indexer_provider.indexer_for(record.collab_type) {

View file

@ -1,7 +1,7 @@
use crate::collab_indexer::IndexerProvider;
use crate::entity::{EmbeddingRecord, UnindexedCollab};
use crate::scheduler::{batch_insert_records, IndexerScheduler};
use crate::vector::embedder::Embedder;
use crate::vector::embedder::AFEmbedder;
use appflowy_ai_client::dto::EmbeddingModel;
use collab::core::collab::DataSource;
use collab::core::origin::CollabOrigin;
@ -157,7 +157,7 @@ async fn stream_unindexed_collabs(
.boxed()
}
async fn create_embeddings(
embedder: Embedder,
embedder: AFEmbedder,
indexer_provider: &Arc<IndexerProvider>,
unindexed_records: Vec<UnindexedCollab>,
existing_embeddings: HashMap<Uuid, Vec<String>>,

View file

@ -1,24 +1,28 @@
use crate::vector::open_ai;
use crate::vector::open_ai::async_embed;
use app_error::AppError;
use appflowy_ai_client::dto::{EmbeddingModel, EmbeddingRequest, OpenAIEmbeddingResponse};
use appflowy_ai_client::dto::EmbeddingModel;
pub use async_openai::config::{AzureConfig, OpenAIConfig};
pub use async_openai::types::{
CreateEmbeddingRequest, CreateEmbeddingRequestArgs, CreateEmbeddingResponse, EmbeddingInput,
EncodingFormat,
};
use infra::env_util::get_env_var_opt;
#[derive(Debug, Clone)]
pub enum Embedder {
OpenAI(open_ai::Embedder),
pub enum AFEmbedder {
OpenAI(open_ai::OpenAIEmbedder),
AzureOpenAI(open_ai::AzureOpenAIEmbedder),
}
impl Embedder {
pub fn embed(&self, params: EmbeddingRequest) -> Result<OpenAIEmbeddingResponse, AppError> {
match self {
Self::OpenAI(embedder) => embedder.embed(params),
}
}
impl AFEmbedder {
pub async fn async_embed(
&self,
params: EmbeddingRequest,
) -> Result<OpenAIEmbeddingResponse, AppError> {
params: CreateEmbeddingRequest,
) -> Result<CreateEmbeddingResponse, AppError> {
match self {
Self::OpenAI(embedder) => embedder.async_embed(params).await,
Self::OpenAI(embedder) => async_embed(&embedder.client, params).await,
Self::AzureOpenAI(embedder) => async_embed(&embedder.client, params).await,
}
}
@ -26,3 +30,20 @@ impl Embedder {
EmbeddingModel::TextEmbedding3Small
}
}
/// Builds the OpenAI client configuration from the `AI_OPENAI_API_KEY`
/// environment variable.
///
/// Returns `None` when `get_env_var_opt` yields no value for that variable.
pub fn open_ai_config() -> Option<OpenAIConfig> {
  let api_key = get_env_var_opt("AI_OPENAI_API_KEY")?;
  Some(OpenAIConfig::default().with_api_key(api_key))
}
/// Builds the Azure OpenAI client configuration from the environment.
///
/// All three of `AI_AZURE_OPENAI_API_KEY`, `AI_AZURE_OPENAI_API_BASE` and
/// `AI_AZURE_OPENAI_API_VERSION` must be provided; as soon as one lookup
/// yields nothing the remaining ones are skipped and `None` is returned.
pub fn azure_open_ai_config() -> Option<AzureConfig> {
  get_env_var_opt("AI_AZURE_OPENAI_API_KEY").and_then(|key| {
    // Lookups stay in key -> base -> version order and short-circuit on the first miss.
    let base = get_env_var_opt("AI_AZURE_OPENAI_API_BASE")?;
    let version = get_env_var_opt("AI_AZURE_OPENAI_API_VERSION")?;
    let config = AzureConfig::new()
      .with_api_key(key)
      .with_api_base(base)
      .with_api_version(version);
    Some(config)
  })
}

View file

@ -1,3 +1,2 @@
pub mod embedder;
pub mod open_ai;
mod rest;

View file

@ -1,9 +1,7 @@
use crate::vector::rest::check_ureq_response;
use anyhow::anyhow;
use app_error::AppError;
use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse};
use serde::de::DeserializeOwned;
use std::time::Duration;
use async_openai::config::{AzureConfig, Config, OpenAIConfig};
use async_openai::types::{CreateEmbeddingRequest, CreateEmbeddingResponse};
use async_openai::Client;
use tiktoken_rs::CoreBPE;
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
@ -11,98 +9,42 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
pub const REQUEST_PARALLELISM: usize = 40;
#[derive(Debug, Clone)]
pub struct Embedder {
bearer: String,
sync_client: ureq::Agent,
async_client: reqwest::Client,
pub struct OpenAIEmbedder {
pub(crate) client: Client<OpenAIConfig>,
}
impl Embedder {
pub fn new(api_key: String) -> Self {
let bearer = format!("Bearer {api_key}");
let sync_client = ureq::AgentBuilder::new()
.max_idle_connections(REQUEST_PARALLELISM * 2)
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
.build();
impl OpenAIEmbedder {
pub fn new(config: OpenAIConfig) -> Self {
let client = Client::with_config(config);
let async_client = reqwest::Client::builder().build().unwrap();
Self {
bearer,
sync_client,
async_client,
Self { client }
}
}
pub fn embed(&self, params: EmbeddingRequest) -> Result<OpenAIEmbeddingResponse, AppError> {
for attempt in 0..3 {
let request = self
.sync_client
.post(OPENAI_EMBEDDINGS_URL)
.set("Authorization", &self.bearer)
.set("Content-Type", "application/json");
let result = check_ureq_response(request.send_json(&params));
let retry_duration = match result {
Ok(response) => {
let data = from_ureq_response::<OpenAIEmbeddingResponse>(response)?;
return Ok(data);
},
Err(retry) => retry.into_duration(attempt),
}
.map_err(|err| AppError::Internal(err.into()))?;
let retry_duration = retry_duration.min(Duration::from_secs(10));
std::thread::sleep(retry_duration);
#[derive(Debug, Clone)]
pub struct AzureOpenAIEmbedder {
pub(crate) client: Client<AzureConfig>,
}
Err(AppError::Internal(anyhow!(
"Failed to generate embeddings after 3 attempts"
)))
impl AzureOpenAIEmbedder {
pub fn new(config: AzureConfig) -> Self {
let client = Client::with_config(config);
Self { client }
}
}
pub async fn async_embed(
&self,
params: EmbeddingRequest,
) -> Result<OpenAIEmbeddingResponse, AppError> {
let request = self
.async_client
.post(OPENAI_EMBEDDINGS_URL)
.header("Authorization", &self.bearer)
.header("Content-Type", "application/json");
let result = request.json(&params).send().await?;
let response = from_response::<OpenAIEmbeddingResponse>(result).await?;
pub async fn async_embed<C: Config>(
client: &Client<C>,
request: CreateEmbeddingRequest,
) -> Result<CreateEmbeddingResponse, AppError> {
let response = client
.embeddings()
.create(request)
.await
.map_err(|err| AppError::Unhandled(err.to_string()))?;
Ok(response)
}
}
pub fn from_ureq_response<T>(resp: ureq::Response) -> Result<T, anyhow::Error>
where
T: DeserializeOwned,
{
let status_code = resp.status();
if status_code != 200 {
let body = resp.into_string()?;
anyhow::bail!("error code: {}, {}", status_code, body)
}
let resp = resp.into_json()?;
Ok(resp)
}
pub async fn from_response<T>(resp: reqwest::Response) -> Result<T, anyhow::Error>
where
T: DeserializeOwned,
{
let status_code = resp.status();
if status_code != 200 {
let body = resp.text().await?;
anyhow::bail!("error code: {}, {}", status_code, body)
}
let resp = resp.json().await?;
Ok(resp)
}
/// ## Execution Time Comparison Results
///
/// The following results were observed when running `execution_time_comparison_tests`:

View file

@ -1,189 +0,0 @@
use crate::thread_pool::CatchedPanic;
/// Error raised while requesting embeddings, tagged with the party held
/// responsible for it.
///
/// The `#[error]` attribute below renders it as `<fault>: <kind>`.
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct EmbedError {
/// What went wrong.
pub kind: EmbedErrorKind,
/// Who the failure is attributed to.
pub fault: FaultSource,
}
impl EmbedError {
pub(crate) fn rest_unauthorized(error_response: Option<String>) -> EmbedError {
Self {
kind: EmbedErrorKind::RestUnauthorized(error_response),
fault: FaultSource::User,
}
}
pub(crate) fn rest_too_many_requests(error_response: Option<String>) -> EmbedError {
Self {
kind: EmbedErrorKind::RestTooManyRequests(error_response),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_bad_request(error_response: Option<String>) -> EmbedError {
Self {
kind: EmbedErrorKind::RestBadRequest(error_response),
fault: FaultSource::User,
}
}
pub(crate) fn rest_internal_server_error(
code: u16,
error_response: Option<String>,
) -> EmbedError {
Self {
kind: EmbedErrorKind::RestInternalServerError(code, error_response),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_other_status_code(code: u16, error_response: Option<String>) -> EmbedError {
Self {
kind: EmbedErrorKind::RestOtherStatusCode(code, error_response),
fault: FaultSource::Unhandled,
}
}
pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError {
Self {
kind: EmbedErrorKind::RestNetwork(transport),
fault: FaultSource::Runtime,
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum EmbedErrorKind {
#[error("could not authenticate against {}", option_info(.0.as_deref(), "server replied with "))]
RestUnauthorized(Option<String>),
#[error("sent too many requests to embedding server{}", option_info(.0.as_deref(), "server replied with "))]
RestTooManyRequests(Option<String>),
#[error("received bad request HTTP from embedding server{}", option_info(.0.as_deref(), "server replied with "))]
RestBadRequest(Option<String>),
#[error("received internal error HTTP {0} from embedding server{}", option_info(.1.as_deref(), "server replied with "))]
RestInternalServerError(u16, Option<String>),
#[error("received unexpected HTTP {0} from embedding server{}", option_info(.1.as_deref(), "server replied with "))]
RestOtherStatusCode(u16, Option<String>),
#[error("could not reach embedding server:\n - {0}")]
RestNetwork(ureq::Transport),
#[error(transparent)]
PanicInThreadPool(#[from] CatchedPanic),
}
/// Formats optional extra detail as an indented bullet line for error messages.
///
/// Yields ``"\n - <prefix>`<info>`"`` when `info` is present and an empty
/// string otherwise, so it can be appended unconditionally to a message.
fn option_info(info: Option<&str>, prefix: &str) -> String {
  info
    .map(|info| format!("\n - {prefix}`{info}`"))
    .unwrap_or_default()
}
/// The party an embedding failure is attributed to.
#[derive(Debug, Clone, Copy)]
pub enum FaultSource {
  /// Attributed to the user.
  User,
  /// Attributed to the runtime.
  Runtime,
  /// Not attributed to a specific party.
  Unhandled,
}

impl std::fmt::Display for FaultSource {
  /// Writes the short human-readable label used as the prefix of error messages.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.write_str(match *self {
      Self::User => "user error",
      Self::Runtime => "runtime error",
      Self::Unhandled => "error",
    })
  }
}
/// A failed embedding request together with the decision on how to follow up.
pub struct Retry {
  /// The failure that led to this decision.
  pub error: EmbedError,
  strategy: RetryStrategy,
}

/// How a failed request should be followed up.
pub enum RetryStrategy {
  /// Do not retry; surface the error.
  GiveUp,
  /// Retry after an exponential backoff.
  Retry,
  /// Retry after an exponential backoff plus a fixed extra delay.
  RetryAfterRateLimit,
}

impl Retry {
  /// Extra delay, in milliseconds, added on top of the backoff after a rate limit.
  const RATE_LIMIT_EXTRA_MS: u64 = 100;

  /// Marks `error` as final: no retry should be attempted.
  pub fn give_up(error: EmbedError) -> Self {
    Self {
      error,
      strategy: RetryStrategy::GiveUp,
    }
  }

  /// Marks `error` as retryable after a plain exponential backoff.
  pub fn retry_later(error: EmbedError) -> Self {
    Self {
      error,
      strategy: RetryStrategy::Retry,
    }
  }

  /// Marks `error` as retryable after the rate-limit delay.
  pub fn rate_limited(error: EmbedError) -> Self {
    Self {
      error,
      strategy: RetryStrategy::RetryAfterRateLimit,
    }
  }

  /// Converts the retry strategy into the delay to wait before the next attempt.
  ///
  /// `attempt` is used directly as the exponent, so a zero-based counter
  /// starts at `10^0 = 1` millisecond.
  ///
  /// - **GiveUp**: returns `Err` carrying `self.error`; no further retry
  ///   should be attempted.
  /// - **Retry**: waits `10^attempt` milliseconds (1 ms, 10 ms, 100 ms, ...
  ///   for attempts 0, 1, 2, ...).
  /// - **RetryAfterRateLimit**: waits `100 + 10^attempt` milliseconds
  ///   (101 ms, 110 ms, 200 ms, ... for attempts 0, 1, 2, ...).
  ///
  /// The arithmetic saturates at `u64::MAX` milliseconds instead of
  /// overflowing, which `10^attempt` would otherwise do from `attempt = 20`.
  pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, Box<EmbedError>> {
    // saturating_pow: 10^20 no longer fits in a u64.
    let backoff_ms = 10u64.saturating_pow(attempt);
    match self.strategy {
      RetryStrategy::GiveUp => Err(Box::new(self.error)),
      RetryStrategy::Retry => Ok(std::time::Duration::from_millis(backoff_ms)),
      RetryStrategy::RetryAfterRateLimit => Ok(std::time::Duration::from_millis(
        backoff_ms.saturating_add(Self::RATE_LIMIT_EXTRA_MS),
      )),
    }
  }
}
#[allow(clippy::result_large_err)]
pub(crate) fn check_ureq_response(
response: Result<ureq::Response, ureq::Error>,
) -> Result<ureq::Response, Retry> {
match response {
Ok(response) => Ok(response),
Err(ureq::Error::Status(code, response)) => {
let error_response: Option<String> = response.into_string().ok();
Err(match code {
401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)),
429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
400 => Retry::give_up(EmbedError::rest_bad_request(error_response)),
500..=599 => {
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
},
402..=499 => Retry::give_up(EmbedError::rest_other_status_code(code, error_response)),
_ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
})
},
Err(ureq::Error::Transport(transport)) => {
Err(Retry::retry_later(EmbedError::rest_network(transport)))
},
}
}

View file

@ -21,9 +21,10 @@ use axum::response::IntoResponse;
use axum::routing::get;
use indexer::metrics::EmbeddingMetrics;
use indexer::thread_pool::ThreadPoolNoAbortBuilder;
use indexer::vector::embedder::{azure_open_ai_config, open_ai_config};
use infra::env_util::get_env_var;
use mailer::sender::Mailer;
use secrecy::{ExposeSecret, Secret};
use secrecy::ExposeSecret;
use std::sync::{Arc, Once};
use std::time::Duration;
use tokio::net::TcpListener;
@ -131,21 +132,24 @@ pub async fn create_app(listener: TcpListener, config: Config) -> Result<(), Err
.unwrap(),
);
let open_ai_config = open_ai_config();
let azure_ai_config = azure_open_ai_config();
let indexer_config = BackgroundIndexerConfig {
enable: appflowy_collaborate::config::get_env_var("APPFLOWY_INDEXER_ENABLED", "true")
.parse::<bool>()
.unwrap_or(true),
open_ai_config,
azure_ai_config,
tick_interval_secs: 10,
};
tokio::spawn(run_background_indexer(
state.pg_pool.clone(),
state.redis_client.clone(),
state.metrics.embedder_metrics.clone(),
threads.clone(),
BackgroundIndexerConfig {
enable: appflowy_collaborate::config::get_env_var("APPFLOWY_INDEXER_ENABLED", "true")
.parse::<bool>()
.unwrap_or(true),
open_api_key: Secret::new(appflowy_collaborate::config::get_env_var(
"AI_OPENAI_API_KEY",
"",
)),
tick_interval_secs: 10,
},
indexer_config,
));
let app = Router::new()

View file

@ -10,10 +10,9 @@ use indexer::queue::{
};
use indexer::scheduler::{spawn_pg_write_embeddings, UnindexedCollabTask, UnindexedData};
use indexer::thread_pool::ThreadPoolNoAbort;
use indexer::vector::embedder::Embedder;
use indexer::vector::embedder::{AFEmbedder, AzureConfig, OpenAIConfig};
use indexer::vector::open_ai;
use redis::aio::ConnectionManager;
use secrecy::{ExposeSecret, Secret};
use sqlx::PgPool;
use std::sync::Arc;
use std::time::{Duration, Instant};
@ -25,7 +24,8 @@ use tracing::{error, info, trace, warn};
pub struct BackgroundIndexerConfig {
pub enable: bool,
pub open_api_key: Secret<String>,
pub open_ai_config: Option<OpenAIConfig>,
pub azure_ai_config: Option<AzureConfig>,
pub tick_interval_secs: u64,
}
@ -41,7 +41,7 @@ pub async fn run_background_indexer(
return;
}
if config.open_api_key.expose_secret().is_empty() {
if config.open_ai_config.is_none() && config.azure_ai_config.is_none() {
error!("OpenAI API key is not set. Stop background indexer");
return;
}
@ -160,7 +160,7 @@ async fn process_upcoming_tasks(
let mut join_set = JoinSet::new();
for task in tasks {
if let Some(indexer) = indexer_provider.indexer_for(task.collab_type) {
let embedder = create_embedder(&config);
if let Ok(embedder) = create_embedder(&config) {
trace!(
"[Background Embedding] processing task: {}, content:{:?}, collab_type: {}",
task.object_id,
@ -206,6 +206,7 @@ async fn process_upcoming_tasks(
});
}
}
}
while let Some(Ok(result)) = join_set.join_next().await {
match result {
@ -254,8 +255,20 @@ async fn process_upcoming_tasks(
}
}
fn create_embedder(config: &BackgroundIndexerConfig) -> Embedder {
Embedder::OpenAI(open_ai::Embedder::new(
config.open_api_key.expose_secret().clone(),
fn create_embedder(config: &BackgroundIndexerConfig) -> Result<AFEmbedder, AppError> {
if let Some(config) = &config.azure_ai_config {
return Ok(AFEmbedder::AzureOpenAI(open_ai::AzureOpenAIEmbedder::new(
config.clone(),
)));
}
if let Some(config) = &config.open_ai_config {
return Ok(AFEmbedder::OpenAI(open_ai::OpenAIEmbedder::new(
config.clone(),
)));
}
Err(AppError::AIServiceUnavailable(
"No embedder available".to_string(),
))
}

View file

@ -27,7 +27,7 @@ use aws_sdk_s3::types::{
BucketInfo, BucketLocationConstraint, BucketType, CreateBucketConfiguration,
};
use mailer::config::MailerSetting;
use secrecy::{ExposeSecret, Secret};
use secrecy::ExposeSecret;
use sqlx::{postgres::PgPoolOptions, PgPool};
use tokio::sync::RwLock;
use tracing::{error, info};
@ -45,6 +45,7 @@ use collab_stream::stream_router::{StreamRouter, StreamRouterOptions};
use database::file::s3_client_impl::{AwsS3BucketClientImpl, S3BucketStorage};
use indexer::collab_indexer::IndexerProvider;
use indexer::scheduler::{IndexerConfiguration, IndexerScheduler};
use indexer::vector::embedder::{azure_open_ai_config, open_ai_config};
use infra::env_util::get_env_var;
use mailer::sender::Mailer;
use snowflake::Snowflake;
@ -298,11 +299,14 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result<A
let mailer = get_mailer(&config.mailer).await?;
info!("Setting up Indexer scheduler...");
let open_ai_config = open_ai_config();
let azure_ai_config = azure_open_ai_config();
let embedder_config = IndexerConfiguration {
enable: get_env_var("APPFLOWY_INDEXER_ENABLED", "true")
.parse::<bool>()
.unwrap_or(true),
openai_api_key: Secret::new(get_env_var("AI_OPENAI_API_KEY", "")),
open_ai_config,
azure_ai_config,
embedding_buffer_size: appflowy_collaborate::config::get_env_var(
"APPFLOWY_INDEXER_EMBEDDING_BUFFER_SIZE",
"5000",

View file

@ -4,9 +4,7 @@ use crate::{
api::metrics::RequestMetrics, biz::collab::folder_view::private_space_and_trash_view_ids,
};
use app_error::AppError;
use appflowy_ai_client::dto::{
EmbeddingEncodingFormat, EmbeddingInput, EmbeddingModel, EmbeddingOutput, EmbeddingRequest,
};
use appflowy_ai_client::dto::EmbeddingModel;
use appflowy_collaborate::collab::storage::CollabAccessControlStorage;
use collab_folder::{Folder, View};
use database::collab::GetCollabOrigin;
@ -20,6 +18,7 @@ use shared_entity::dto::search_dto::{
use sqlx::PgPool;
use indexer::scheduler::IndexerScheduler;
use indexer::vector::embedder::{CreateEmbeddingRequestArgs, EmbeddingInput, EncodingFormat};
use uuid::Uuid;
static MAX_SEARCH_DEPTH: i32 = 10;
@ -80,15 +79,18 @@ pub async fn search_document(
request: SearchDocumentRequest,
metrics: &RequestMetrics,
) -> Result<Vec<SearchDocumentResponseItem>, AppError> {
let embeddings = indexer_scheduler
.create_search_embeddings(EmbeddingRequest {
input: EmbeddingInput::String(request.query.clone()),
model: EmbeddingModel::TextEmbedding3Small.to_string(),
encoding_format: EmbeddingEncodingFormat::Float,
dimensions: EmbeddingModel::TextEmbedding3Small.default_dimensions(),
})
let embeddings_request = CreateEmbeddingRequestArgs::default()
.model(EmbeddingModel::TextEmbedding3Small.to_string())
.input(EmbeddingInput::String(request.query.clone()))
.encoding_format(EncodingFormat::Float)
.dimensions(EmbeddingModel::TextEmbedding3Small.default_dimensions())
.build()
.map_err(|err| AppError::Unhandled(err.to_string()))?;
let mut embeddings_resp = indexer_scheduler
.create_search_embeddings(embeddings_request)
.await?;
let total_tokens = embeddings.usage.total_tokens as u32;
let total_tokens = embeddings_resp.usage.total_tokens;
metrics.record_search_tokens_used(&workspace_uuid, total_tokens);
tracing::info!(
"workspace {} OpenAI API search tokens used: {}",
@ -96,18 +98,10 @@ pub async fn search_document(
total_tokens
);
let embedding = embeddings
let embedding = embeddings_resp
.data
.first()
.pop()
.ok_or_else(|| AppError::Internal(anyhow::anyhow!("OpenAI returned no embeddings")))?;
let embedding = match &embedding.embedding {
EmbeddingOutput::Float(vector) => vector.iter().map(|&v| v as f32).collect(),
EmbeddingOutput::Base64(_) => {
return Err(AppError::Internal(anyhow::anyhow!(
"OpenAI returned embeddings in unsupported format"
)))
},
};
let folder = get_latest_collab_folder(
collab_storage,
@ -133,7 +127,7 @@ pub async fn search_document(
workspace_id: workspace_uuid,
limit: request.limit.unwrap_or(10) as i32,
preview: request.preview_size.unwrap_or(500) as i32,
embedding,
embedding: embedding.embedding,
searchable_view_ids: searchable_view_ids.into_iter().collect(),
},
total_tokens,