fix: support index document in the background, fix stack overflow when calling rayon::spawn (#1099)

* chore: batch index

* chore: format log

* chore: index workspace

* chore: fix stack overflow

* chore: background index

* chore: clippy

* chore: filter tasks

* chore: clippy

* chore: add metrics

* chore: fix test
This commit is contained in:
Nathan.fooo 2024-12-24 14:30:17 +08:00 committed by GitHub
parent 381b02a4d0
commit 1131818eb7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
56 changed files with 2185 additions and 355 deletions

View file

@ -1,32 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n select c.workspace_id, c.oid, c.partition_key\n from af_collab c\n join af_workspace w on c.workspace_id = w.workspace_id\n where not coalesce(w.settings['disable_search_indexding']::boolean, false)\n and not exists (\n select 1 from af_collab_embeddings em\n where em.oid = c.oid and em.partition_key = 0\n )\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "workspace_id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "oid",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "partition_key",
"type_info": "Int4"
}
],
"parameters": {
"Left": []
},
"nullable": [
false,
false,
false
]
},
"hash": "ad216288cbbe83aba35b5d04705ee5964f1da4f3839c4725a6784c13f2245379"
}

View file

@ -0,0 +1,16 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE af_collab\n SET indexed_at = $1\n WHERE oid = $2 AND partition_key = $3\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Timestamptz",
"Text",
"Int4"
]
},
"nullable": []
},
"hash": "d0e5f5097b35a15f19e9e7faf2c62336d5f130e939331e84c7d834f6028ea673"
}

View file

@ -0,0 +1,35 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT c.workspace_id, c.oid, c.partition_key\n FROM af_collab c\n JOIN af_workspace w ON c.workspace_id = w.workspace_id\n WHERE c.workspace_id = $1\n AND NOT COALESCE(w.settings['disable_search_indexing']::boolean, false)\n AND c.indexed_at IS NULL\n ORDER BY c.updated_at DESC\n LIMIT $2\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "workspace_id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "oid",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "partition_key",
"type_info": "Int4"
}
],
"parameters": {
"Left": [
"Uuid",
"Int8"
]
},
"nullable": [
false,
false,
false
]
},
"hash": "f68cc2042d6aa78feeb33640e9ef13f46c5e10ee269ea0bd965b0e57dee6cf94"
}

View file

@ -0,0 +1,29 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT oid, indexed_at\n FROM af_collab\n WHERE (oid, partition_key) = ANY (\n SELECT UNNEST($1::text[]), UNNEST($2::int[])\n )\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "oid",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "indexed_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"TextArray",
"Int4Array"
]
},
"nullable": [
false,
true
]
},
"hash": "f8c909517885cb30e3f7d573edf47138f90ea9c5fa73eb927cc5487c3d9ad0be"
}

64
Cargo.lock generated
View file

@ -648,6 +648,7 @@ dependencies = [
"hex",
"http 0.2.12",
"image",
"indexer",
"infra",
"itertools 0.11.0",
"lazy_static",
@ -736,6 +737,7 @@ dependencies = [
"futures",
"futures-util",
"governor",
"indexer",
"indexmap 2.3.0",
"itertools 0.12.1",
"lazy_static",
@ -775,6 +777,8 @@ name = "appflowy-worker"
version = "0.1.0"
dependencies = [
"anyhow",
"app-error",
"appflowy-collaborate",
"async_zip",
"aws-config",
"aws-sdk-s3",
@ -790,11 +794,13 @@ dependencies = [
"database-entity",
"dotenvy",
"futures",
"indexer",
"infra",
"mailer",
"md5",
"mime_guess",
"prometheus-client",
"rayon",
"redis 0.25.4",
"reqwest",
"secrecy",
@ -2849,9 +2855,9 @@ dependencies = [
[[package]]
name = "dashmap"
version = "6.0.1"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "804c8821570c3f8b70230c2ba75ffa5c0f9a4189b9a432b6656c536712acae28"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
@ -4195,6 +4201,42 @@ version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "206ca75c9c03ba3d4ace2460e57b189f39f43de612c2f85836e65c929701bb2d"
[[package]]
name = "indexer"
version = "0.1.0"
dependencies = [
"anyhow",
"app-error",
"appflowy-ai-client",
"async-trait",
"bytes",
"chrono",
"collab",
"collab-document",
"collab-entity",
"collab-folder",
"collab-stream",
"dashmap 6.1.0",
"database",
"database-entity",
"futures-util",
"infra",
"prometheus-client",
"rayon",
"redis 0.25.4",
"serde",
"serde_json",
"sqlx",
"thiserror 1.0.63",
"tiktoken-rs",
"tokio",
"tokio-util",
"tracing",
"unicode-segmentation",
"ureq",
"uuid",
]
[[package]]
name = "indexmap"
version = "1.9.3"
@ -7293,11 +7335,11 @@ dependencies = [
[[package]]
name = "thiserror"
version = "2.0.8"
version = "2.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08f5383f3e0071702bf93ab5ee99b52d26936be9dedd9413067cbdcddcb6141a"
checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc"
dependencies = [
"thiserror-impl 2.0.8",
"thiserror-impl 2.0.9",
]
[[package]]
@ -7313,9 +7355,9 @@ dependencies = [
[[package]]
name = "thiserror-impl"
version = "2.0.8"
version = "2.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2f357fcec90b3caef6623a099691be676d033b40a058ac95d2a6ade6fa0c943"
checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4"
dependencies = [
"proc-macro2",
"quote",
@ -7882,7 +7924,7 @@ dependencies = [
"native-tls",
"rand 0.8.5",
"sha1",
"thiserror 2.0.8",
"thiserror 2.0.9",
"utf-8",
]
@ -7942,9 +7984,9 @@ checksum = "e4259d9d4425d9f0661581b804cb85fe66a4c631cadd8f490d1c13a35d5d9291"
[[package]]
name = "unicode-segmentation"
version = "1.11.0"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
@ -8678,7 +8720,7 @@ dependencies = [
"arc-swap",
"async-lock",
"async-trait",
"dashmap 6.0.1",
"dashmap 6.1.0",
"fastrand",
"serde",
"serde_json",

View file

@ -156,6 +156,7 @@ base64.workspace = true
md5.workspace = true
nanoid = "0.4.0"
http.workspace = true
indexer.workspace = true
[dev-dependencies]
once_cell = "1.19.0"
@ -176,7 +177,6 @@ collab-rt-entity = { path = "libs/collab-rt-entity" }
hex = "0.4.3"
unicode-normalization = "0.1.24"
[[bin]]
name = "appflowy_cloud"
path = "src/main.rs"
@ -221,9 +221,11 @@ members = [
"xtask",
"libs/tonic-proto",
"libs/mailer",
"libs/indexer",
]
[workspace.dependencies]
indexer = { path = "libs/indexer" }
collab-rt-entity = { path = "libs/collab-rt-entity" }
collab-rt-protocol = { path = "libs/collab-rt-protocol" }
database = { path = "libs/database" }

View file

@ -157,6 +157,7 @@ APPFLOWY_LOCAL_AI_TEST_ENABLED=false
APPFLOWY_INDEXER_ENABLED=true
APPFLOWY_INDEXER_DATABASE_URL=postgres://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
APPFLOWY_INDEXER_REDIS_URL=redis://${REDIS_HOST}:${REDIS_PORT}
APPFLOWY_INDEXER_EMBEDDING_BUFFER_SIZE=5000
# AppFlowy Collaborate
APPFLOWY_COLLABORATE_MULTI_THREAD=false

View file

@ -124,6 +124,7 @@ APPFLOWY_LOCAL_AI_TEST_ENABLED=false
APPFLOWY_INDEXER_ENABLED=true
APPFLOWY_INDEXER_DATABASE_URL=postgres://postgres:password@postgres:5432/postgres
APPFLOWY_INDEXER_REDIS_URL=redis://redis:6379
APPFLOWY_INDEXER_EMBEDDING_BUFFER_SIZE=5000
# AppFlowy Collaborate
APPFLOWY_COLLABORATE_MULTI_THREAD=false

View file

@ -182,6 +182,9 @@ pub enum AppError {
#[error("Decode update error: {0}")]
DecodeUpdateError(String),
#[error("{0}")]
ActionTimeout(String),
#[error("Apply update error:{0}")]
ApplyUpdateError(String),
}
@ -263,6 +266,7 @@ impl AppError {
AppError::ServiceTemporaryUnavailable(_) => ErrorCode::ServiceTemporaryUnavailable,
AppError::DecodeUpdateError(_) => ErrorCode::DecodeUpdateError,
AppError::ApplyUpdateError(_) => ErrorCode::ApplyUpdateError,
AppError::ActionTimeout(_) => ErrorCode::ActionTimeout,
}
}
}
@ -316,6 +320,7 @@ impl From<sqlx::Error> for AppError {
sqlx::Error::RowNotFound => {
AppError::RecordNotFound(format!("Record not exist in db. {})", msg))
},
sqlx::Error::PoolTimedOut => AppError::ActionTimeout(value.to_string()),
_ => AppError::SqlxError(msg),
}
}
@ -424,6 +429,7 @@ pub enum ErrorCode {
ServiceTemporaryUnavailable = 1054,
DecodeUpdateError = 1055,
ApplyUpdateError = 1056,
ActionTimeout = 1057,
}
impl ErrorCode {

View file

@ -6,14 +6,8 @@ use collab::preclude::Collab;
use collab_entity::CollabType;
use tracing::instrument;
#[instrument(level = "trace", skip(data), fields(len = %data.len()))]
#[inline]
pub async fn spawn_blocking_validate_encode_collab(
object_id: &str,
data: &[u8],
collab_type: &CollabType,
) -> Result<(), Error> {
let collab_type = collab_type.clone();
pub async fn collab_from_encode_collab(object_id: &str, data: &[u8]) -> Result<Collab, Error> {
let object_id = object_id.to_string();
let data = data.to_vec();
@ -27,28 +21,19 @@ pub async fn spawn_blocking_validate_encode_collab(
false,
)?;
collab_type.validate_require_data(&collab)?;
Ok::<(), Error>(())
Ok::<_, Error>(collab)
})
.await?
}
#[instrument(level = "trace", skip(data), fields(len = %data.len()))]
#[inline]
pub fn validate_encode_collab(
pub async fn validate_encode_collab(
object_id: &str,
data: &[u8],
collab_type: &CollabType,
) -> Result<(), Error> {
let encoded_collab = EncodedCollab::decode_from_bytes(data)?;
let collab = Collab::new_with_source(
CollabOrigin::Empty,
object_id,
DataSource::DocStateV1(encoded_collab.doc_state.to_vec()),
vec![],
false,
)?;
let collab = collab_from_encode_collab(object_id, data).await?;
collab_type.validate_require_data(&collab)?;
Ok::<(), Error>(())
}

View file

@ -76,7 +76,7 @@ async fn single_group_async_read_message_test() {
}
#[tokio::test]
async fn different_group_read_message_test() {
async fn different_group_read_undelivered_message_test() {
let oid = format!("o{}", random_i64());
let client = stream_client().await;
let mut group_1 = client.collab_update_stream("w1", &oid, "g1").await.unwrap();
@ -101,6 +101,40 @@ async fn different_group_read_message_test() {
assert_eq!(group_2_messages[0].data, vec![1, 2, 3, 4, 5]);
}
#[tokio::test]
async fn different_group_read_message_test() {
let oid = format!("o{}", random_i64());
let client = stream_client().await;
let mut group_1 = client.collab_update_stream("w1", &oid, "g1").await.unwrap();
let mut group_2 = client.collab_update_stream("w1", &oid, "g2").await.unwrap();
let msg = StreamBinary(vec![1, 2, 3, 4, 5]);
{
let client = stream_client().await;
let mut group = client.collab_update_stream("w1", &oid, "g2").await.unwrap();
group.insert_binary(msg).await.unwrap();
}
let msg = group_1
.consumer_messages("consumer1", ReadOption::Count(1))
.await
.unwrap();
group_1.ack_messages(&msg).await.unwrap();
let (result1, result2) = join(
group_1.consumer_messages("consumer1", ReadOption::Count(1)),
group_2.consumer_messages("consumer1", ReadOption::Count(1)),
)
.await;
let group_1_messages = result1.unwrap();
let group_2_messages = result2.unwrap();
// consumer1 already acked the message before. so it should not be available
assert!(group_1_messages.is_empty());
assert_eq!(group_2_messages[0].data, vec![1, 2, 3, 4, 5]);
}
#[tokio::test]
async fn read_specific_num_of_message_test() {
let object_id = format!("o{}", random_i64());

View file

@ -742,7 +742,6 @@ impl From<i16> for AFWorkspaceInvitationStatus {
pub struct AFCollabEmbeddedChunk {
pub fragment_id: String,
pub object_id: String,
pub collab_type: CollabType,
pub content_type: EmbeddingContentType,
pub content: String,
pub embedding: Option<Vec<f32>>,

View file

@ -1,14 +1,17 @@
use crate::collab::partition_key_from_collab_type;
use chrono::{DateTime, Utc};
use collab_entity::CollabType;
use database_entity::dto::{AFCollabEmbeddedChunk, IndexingStatus, QueryCollab, QueryCollabParams};
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use pgvector::Vector;
use sqlx::pool::PoolConnection;
use sqlx::postgres::{PgHasArrayType, PgTypeInfo};
use sqlx::{Error, Executor, PgPool, Postgres, Transaction};
use sqlx::{Error, Executor, Postgres, Transaction};
use std::collections::HashMap;
use std::ops::DerefMut;
use uuid::Uuid;
use database_entity::dto::{AFCollabEmbeddedChunk, IndexingStatus, QueryCollab, QueryCollabParams};
pub async fn get_index_status<'a, E>(
tx: E,
workspace_id: &Uuid,
@ -89,17 +92,17 @@ impl PgHasArrayType for Fragment {
pub async fn upsert_collab_embeddings(
transaction: &mut Transaction<'_, Postgres>,
workspace_id: &Uuid,
object_id: &str,
collab_type: CollabType,
tokens_used: u32,
records: Vec<AFCollabEmbeddedChunk>,
) -> Result<(), sqlx::Error> {
if records.is_empty() {
return Ok(());
}
let object_id = records[0].object_id.clone();
let collab_type = records[0].collab_type.clone();
let fragments = records.into_iter().map(Fragment::from).collect::<Vec<_>>();
tracing::trace!(
"[Embedding] upsert {} {} fragments",
object_id,
fragments.len()
);
sqlx::query(r#"CALL af_collab_embeddings_upsert($1, $2, $3, $4, $5::af_fragment_v2[])"#)
.bind(*workspace_id)
.bind(object_id)
@ -111,21 +114,26 @@ pub async fn upsert_collab_embeddings(
Ok(())
}
pub fn get_collabs_without_embeddings(pg_pool: &PgPool) -> BoxStream<sqlx::Result<CollabId>> {
// atm. get only documents
pub async fn stream_collabs_without_embeddings(
conn: &mut PoolConnection<Postgres>,
workspace_id: Uuid,
limit: i64,
) -> BoxStream<sqlx::Result<CollabId>> {
sqlx::query!(
r#"
select c.workspace_id, c.oid, c.partition_key
from af_collab c
join af_workspace w on c.workspace_id = w.workspace_id
where not coalesce(w.settings['disable_search_indexding']::boolean, false)
and not exists (
select 1 from af_collab_embeddings em
where em.oid = c.oid and em.partition_key = 0
)
"#
SELECT c.workspace_id, c.oid, c.partition_key
FROM af_collab c
JOIN af_workspace w ON c.workspace_id = w.workspace_id
WHERE c.workspace_id = $1
AND NOT COALESCE(w.settings['disable_search_indexing']::boolean, false)
AND c.indexed_at IS NULL
ORDER BY c.updated_at DESC
LIMIT $2
"#,
workspace_id,
limit
)
.fetch(pg_pool)
.fetch(conn.deref_mut())
.map(|row| {
row.map(|r| CollabId {
collab_type: CollabType::from(r.partition_key),
@ -136,6 +144,71 @@ pub fn get_collabs_without_embeddings(pg_pool: &PgPool) -> BoxStream<sqlx::Resul
.boxed()
}
/// Stamps a single collab row with the moment it was last indexed.
///
/// The row is addressed by `object_id` together with the partition key
/// derived from `collab_type`; only the `indexed_at` column is modified.
pub async fn update_collab_indexed_at<'a, E>(
  tx: E,
  object_id: &str,
  collab_type: &CollabType,
  indexed_at: DateTime<Utc>,
) -> Result<(), Error>
where
  E: Executor<'a, Database = Postgres>,
{
  // The collab table is partitioned by collab type; resolve the key first.
  let partition = partition_key_from_collab_type(collab_type);
  sqlx::query!(
    r#"
        UPDATE af_collab
        SET indexed_at = $1
        WHERE oid = $2 AND partition_key = $3
      "#,
    indexed_at,
    object_id,
    partition
  )
  .execute(tx)
  .await?;

  Ok(())
}
/// Fetches the `indexed_at` timestamps for a batch of collabs.
///
/// Returns a map of `object_id -> indexed_at` containing only rows that have
/// been indexed at least once; collabs whose `indexed_at` is NULL are absent
/// from the result.
pub async fn get_collabs_indexed_at<'a, E>(
  executor: E,
  collab_ids: Vec<(String, CollabType)>,
) -> Result<HashMap<String, DateTime<Utc>>, Error>
where
  E: Executor<'a, Database = Postgres>,
{
  // Split (oid, collab_type) pairs into parallel arrays for the UNNEST query.
  let (oids, partition_keys): (Vec<String>, Vec<i32>) = collab_ids
    .into_iter()
    .map(|(object_id, collab_type)| (object_id, partition_key_from_collab_type(&collab_type)))
    .unzip();

  let result = sqlx::query!(
    r#"
        SELECT oid, indexed_at
        FROM af_collab
        WHERE (oid, partition_key) = ANY (
          SELECT UNNEST($1::text[]), UNNEST($2::int[])
        )
      "#,
    &oids,
    &partition_keys
  )
  .fetch_all(executor)
  .await?;

  // `indexed_at` is nullable: keep only rows that actually carry a timestamp.
  let map = result
    .into_iter()
    .filter_map(|r| r.indexed_at.map(|indexed_at| (r.oid, indexed_at)))
    .collect::<HashMap<String, DateTime<Utc>>>();
  Ok(map)
}
#[derive(Debug, Clone)]
pub struct CollabId {
pub collab_type: CollabType,

40
libs/indexer/Cargo.toml Normal file
View file

@ -0,0 +1,40 @@
# Background indexing crate: extracts text from collabs and generates embeddings.
[package]
name = "indexer"
version = "0.1.0"
edition = "2021"

[dependencies]
# Thread pool for CPU-bound embedding generation.
rayon.workspace = true
# Token counting for OpenAI embedding models.
tiktoken-rs = "0.6.0"
app-error = { workspace = true }
appflowy-ai-client = { workspace = true, features = ["client-api"] }
# Grapheme-aware text splitting.
unicode-segmentation = "1.12.0"
collab = { workspace = true }
collab-entity = { workspace = true }
collab-folder = { workspace = true }
collab-document = { workspace = true }
collab-stream = { workspace = true }
database-entity.workspace = true
database.workspace = true
futures-util.workspace = true
sqlx.workspace = true
tokio.workspace = true
tracing.workspace = true
thiserror = "1.0.56"
uuid.workspace = true
async-trait.workspace = true
serde_json.workspace = true
anyhow.workspace = true
infra.workspace = true
# Prometheus metrics exposition.
prometheus-client = "0.22.3"
bytes.workspace = true
dashmap = "6.1.0"
chrono = "0.4.39"
# Blocking HTTP client.
ureq = "2.12.1"
serde.workspace = true
# Redis is used as the fallback queue for background embedding tasks.
redis = { workspace = true, features = [
"aio",
"tokio-comp",
"connection-manager",
] }
tokio-util = "0.7.12"

View file

@ -1,6 +1,6 @@
use crate::indexer::open_ai::split_text_by_max_content_len;
use crate::indexer::vector::embedder::Embedder;
use crate::indexer::Indexer;
use crate::collab_indexer::Indexer;
use crate::vector::embedder::Embedder;
use crate::vector::open_ai::split_text_by_max_content_len;
use anyhow::anyhow;
use app_error::AppError;
use appflowy_ai_client::dto::{
@ -20,7 +20,7 @@ pub struct DocumentIndexer;
#[async_trait]
impl Indexer for DocumentIndexer {
fn create_embedded_chunks(
fn create_embedded_chunks_from_collab(
&self,
collab: &Collab,
embedding_model: EmbeddingModel,
@ -35,9 +35,7 @@ impl Indexer for DocumentIndexer {
let result = document.to_plain_text(collab.transact(), false, true);
match result {
Ok(content) => {
split_text_into_chunks(object_id, content, CollabType::Document, &embedding_model)
},
Ok(content) => self.create_embedded_chunks_from_text(object_id, content, embedding_model),
Err(err) => {
if matches!(err, DocumentError::NoRequiredData) {
Ok(vec![])
@ -48,6 +46,15 @@ impl Indexer for DocumentIndexer {
}
}
fn create_embedded_chunks_from_text(
&self,
object_id: String,
text: String,
model: EmbeddingModel,
) -> Result<Vec<AFCollabEmbeddedChunk>, AppError> {
split_text_into_chunks(object_id, text, CollabType::Document, &model)
}
fn embed(
&self,
embedder: &Embedder,
@ -104,6 +111,10 @@ fn split_text_into_chunks(
embedding_model,
EmbeddingModel::TextEmbedding3Small
));
if content.is_empty() {
return Ok(vec![]);
}
// We assume that every token is ~4 bytes. We're going to split document content into fragments
// of ~2000 tokens each.
let split_contents = split_text_by_max_content_len(content, 8000)?;
@ -115,7 +126,6 @@ fn split_text_into_chunks(
.map(|content| AFCollabEmbeddedChunk {
fragment_id: Uuid::new_v4().to_string(),
object_id: object_id.clone(),
collab_type: collab_type.clone(),
content_type: EmbeddingContentType::PlainText,
content,
embedding: None,

View file

@ -0,0 +1,5 @@
//! Indexer implementations that convert collabs into embeddable text chunks.

mod document_indexer;
mod provider;

pub use document_indexer::*;
pub use provider::*;

View file

@ -1,22 +1,29 @@
use crate::config::get_env_var;
use crate::indexer::vector::embedder::Embedder;
use crate::indexer::DocumentIndexer;
use crate::collab_indexer::DocumentIndexer;
use crate::vector::embedder::Embedder;
use app_error::AppError;
use appflowy_ai_client::dto::EmbeddingModel;
use collab::preclude::Collab;
use collab_entity::CollabType;
use database_entity::dto::{AFCollabEmbeddedChunk, AFCollabEmbeddings};
use infra::env_util::get_env_var;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::info;
pub trait Indexer: Send + Sync {
fn create_embedded_chunks(
fn create_embedded_chunks_from_collab(
&self,
collab: &Collab,
model: EmbeddingModel,
) -> Result<Vec<AFCollabEmbeddedChunk>, AppError>;
fn create_embedded_chunks_from_text(
&self,
object_id: String,
text: String,
model: EmbeddingModel,
) -> Result<Vec<AFCollabEmbeddedChunk>, AppError>;
fn embed(
&self,
embedder: &Embedder,

View file

@ -0,0 +1,31 @@
use collab::entity::EncodedCollab;
use collab_entity::CollabType;
use database_entity::dto::AFCollabEmbeddedChunk;
use uuid::Uuid;
/// A collab that has not yet been indexed for embeddings.
pub struct UnindexedCollab {
  /// Workspace the collab belongs to.
  pub workspace_id: Uuid,
  /// Object id of the collab.
  pub object_id: String,
  /// Kind of the collab (document, folder, ...).
  pub collab_type: CollabType,
  /// Encoded collab payload to be decoded and indexed.
  pub collab: EncodedCollab,
}
/// The result of embedding one collab: the chunks to persist plus metadata.
pub struct EmbeddingRecord {
  pub workspace_id: Uuid,
  pub object_id: String,
  pub collab_type: CollabType,
  /// Number of model tokens consumed to produce these embeddings.
  pub tokens_used: u32,
  /// Embedded chunks; empty when there was nothing to embed.
  pub contents: Vec<AFCollabEmbeddedChunk>,
}
impl EmbeddingRecord {
  /// Builds a record with zero token usage and no chunks, identifying a
  /// collab that produced nothing to embed.
  pub fn empty(workspace_id: Uuid, object_id: String, collab_type: CollabType) -> Self {
    Self {
      workspace_id,
      object_id,
      collab_type,
      tokens_used: 0,
      contents: Vec::new(),
    }
  }
}

View file

@ -0,0 +1,8 @@
/// Errors produced by the background indexing pipeline.
#[derive(thiserror::Error, Debug)]
pub enum IndexerError {
  /// The Redis stream consumer group has not been created yet.
  #[error("Redis stream group not exist: {0}")]
  StreamGroupNotExist(String),
  /// Any other unexpected failure, wrapped for propagation.
  #[error(transparent)]
  Internal(#[from] anyhow::Error),
}

9
libs/indexer/src/lib.rs Normal file
View file

@ -0,0 +1,9 @@
//! Background indexing crate: turns collabs into text chunks, generates
//! embeddings for them, and schedules/writes the results for search.

pub mod collab_indexer;
pub mod entity;
pub mod error;
pub mod metrics;
pub mod queue;
pub mod scheduler;
pub mod thread_pool;
mod unindexed_workspace;
pub mod vector;

View file

@ -4,8 +4,9 @@ use prometheus_client::registry::Registry;
pub struct EmbeddingMetrics {
total_embed_count: Counter,
failed_embed_count: Counter,
processing_time_histogram: Histogram,
write_embedding_time_histogram: Histogram,
gen_embeddings_time_histogram: Histogram,
fallback_background_tasks: Counter,
}
impl EmbeddingMetrics {
@ -13,8 +14,9 @@ impl EmbeddingMetrics {
Self {
total_embed_count: Counter::default(),
failed_embed_count: Counter::default(),
processing_time_histogram: Histogram::new([500.0, 1000.0, 5000.0, 8000.0].into_iter()),
write_embedding_time_histogram: Histogram::new([500.0, 1000.0, 5000.0, 8000.0].into_iter()),
gen_embeddings_time_histogram: Histogram::new([1000.0, 3000.0, 5000.0, 8000.0].into_iter()),
fallback_background_tasks: Counter::default(),
}
}
@ -33,17 +35,24 @@ impl EmbeddingMetrics {
"Total count of failed embeddings",
metrics.failed_embed_count.clone(),
);
realtime_registry.register(
"processing_time_seconds",
"Histogram of embedding processing times",
metrics.processing_time_histogram.clone(),
);
realtime_registry.register(
"write_embedding_time_seconds",
"Histogram of embedding write times",
metrics.write_embedding_time_histogram.clone(),
);
realtime_registry.register(
"gen_embeddings_time_histogram",
"Histogram of embedding generation times",
metrics.gen_embeddings_time_histogram.clone(),
);
realtime_registry.register(
"fallback_background_tasks",
"Total count of fallback background tasks",
metrics.fallback_background_tasks.clone(),
);
metrics
}
@ -55,13 +64,16 @@ impl EmbeddingMetrics {
self.failed_embed_count.inc_by(count);
}
pub fn record_generate_embedding_time(&self, millis: u128) {
tracing::trace!("[Embedding]: generate embeddings cost: {}ms", millis);
self.processing_time_histogram.observe(millis as f64);
pub fn record_fallback_background_tasks(&self, count: u64) {
self.fallback_background_tasks.inc_by(count);
}
pub fn record_write_embedding_time(&self, millis: u128) {
tracing::trace!("[Embedding]: write embedding time cost: {}ms", millis);
self.write_embedding_time_histogram.observe(millis as f64);
}
pub fn record_gen_embedding_time(&self, num: u32, millis: u128) {
tracing::info!("[Embedding]: index {} collabs cost: {}ms", num, millis);
self.gen_embeddings_time_histogram.observe(millis as f64);
}
}

172
libs/indexer/src/queue.rs Normal file
View file

@ -0,0 +1,172 @@
use crate::error::IndexerError;
use crate::scheduler::UnindexedCollabTask;
use anyhow::anyhow;
use app_error::AppError;
use redis::aio::ConnectionManager;
use redis::streams::{StreamId, StreamReadOptions, StreamReadReply};
use redis::{AsyncCommands, RedisResult, Value};
use serde_json::from_str;
use tracing::error;
/// Redis stream holding pending background index tasks.
pub const INDEX_TASK_STREAM_NAME: &str = "index_collab_task_stream";
/// Consumer group the indexer workers read the stream through.
const INDEXER_WORKER_GROUP_NAME: &str = "indexer_worker_group";
/// Consumer name registered within the worker group.
const INDEXER_CONSUMER_NAME: &str = "appflowy_worker";
impl TryFrom<&StreamId> for UnindexedCollabTask {
  type Error = IndexerError;

  /// Pulls the JSON payload stored under the `task` field of a Redis stream
  /// entry and deserializes it into an `UnindexedCollabTask`.
  fn try_from(stream_id: &StreamId) -> Result<Self, Self::Error> {
    // The entry must carry a `task` field.
    let Some(value) = stream_id.map.get("task") else {
      error!("Task field not found in Redis stream entry");
      return Err(IndexerError::Internal(anyhow!(
        "Task field not found in Redis stream entry"
      )));
    };
    // The field must be raw bytes; any other Redis value type is unexpected.
    let Value::Data(data) = value else {
      error!("Unexpected value type for task field: {:?}", value);
      return Err(IndexerError::Internal(anyhow!(
        "Unexpected value type for task field: {:?}",
        value
      )));
    };
    let task_str = String::from_utf8_lossy(data).to_string();
    from_str::<UnindexedCollabTask>(&task_str).map_err(|err| IndexerError::Internal(err.into()))
  }
}
/// Adds a list of tasks to the Redis stream.
///
/// This function pushes a batch of `EmbedderTask` items into the Redis stream for processing.
/// The tasks are serialized into JSON format before being added to the stream.
/// Tasks that fail to serialize are skipped (best-effort).
pub async fn add_background_embed_task(
  redis_client: ConnectionManager,
  tasks: Vec<UnindexedCollabTask>,
) -> Result<(), AppError> {
  let items = tasks
    .into_iter()
    .filter_map(|task| {
      let task = serde_json::to_string(&task).ok()?;
      Some(("task", task))
    })
    .collect::<Vec<_>>();

  // XADD with an empty field list is a Redis error ("wrong number of
  // arguments"), so bail out early when there is nothing to enqueue
  // (empty input, or every serialization failed).
  if items.is_empty() {
    return Ok(());
  }

  let _: () = redis_client
    .clone()
    .xadd(INDEX_TASK_STREAM_NAME, "*", &items)
    .await
    .map_err(|err| {
      AppError::Internal(anyhow!(
        "Failed to push embedder task to Redis stream: {}",
        err
      ))
    })?;
  Ok(())
}
/// Reads tasks from the Redis stream for processing by a consumer group.
pub async fn read_background_embed_tasks(
  redis_client: &mut ConnectionManager,
  options: &StreamReadOptions,
) -> Result<StreamReadReply, IndexerError> {
  redis_client
    .xread_options(&[INDEX_TASK_STREAM_NAME], &[">"], options)
    .await
    .map_err(|err| {
      error!("Failed to read tasks from Redis stream: {:?}", err);
      // NOGROUP means the consumer group was never created (or the stream
      // vanished); surface that as a dedicated variant so callers can recreate it.
      if err.code() == Some("NOGROUP") {
        IndexerError::StreamGroupNotExist(INDEXER_WORKER_GROUP_NAME.to_string())
      } else {
        IndexerError::Internal(err.into())
      }
    })
}
/// Acknowledges a task in a Redis stream and optionally removes it from the stream.
///
/// It is used to acknowledge the processing of a task in a Redis stream
/// within a specific consumer group. Once a task is acknowledged, it is removed from
/// the **Pending Entries List (PEL)** for the consumer group. If the `delete_task`
/// flag is set to `true`, the task will also be removed from the Redis stream entirely.
///
/// # Parameters:
/// - `redis_client`: A mutable reference to the Redis `ConnectionManager`, used to
///   interact with the Redis server.
/// - `stream_entity_ids`: The unique identifiers (IDs) of the tasks in the stream.
/// - `delete_task`: A boolean flag that indicates whether the task should be removed
///   from the stream after it is acknowledged. If `true`, the task is deleted from the stream.
///   If `false`, the task remains in the stream after acknowledgment.
pub async fn ack_task(
  redis_client: &mut ConnectionManager,
  stream_entity_ids: Vec<String>,
  delete_task: bool,
) -> Result<(), IndexerError> {
  if let Err(err) = redis_client
    .xack::<_, _, _, ()>(
      INDEX_TASK_STREAM_NAME,
      INDEXER_WORKER_GROUP_NAME,
      &stream_entity_ids,
    )
    .await
  {
    error!("Failed to ack task: {:?}", err);
    return Err(IndexerError::Internal(err.into()));
  }

  if delete_task {
    if let Err(err) = redis_client
      .xdel::<_, _, ()>(INDEX_TASK_STREAM_NAME, &stream_entity_ids)
      .await
    {
      error!("Failed to delete task: {:?}", err);
      return Err(IndexerError::Internal(err.into()));
    }
  }
  Ok(())
}
/// Builds the default stream-read options for the indexer consumer group,
/// capping each read at `limit` entries.
pub fn default_indexer_group_option(limit: usize) -> StreamReadOptions {
  let options = StreamReadOptions::default();
  options
    .group(INDEXER_WORKER_GROUP_NAME, INDEXER_CONSUMER_NAME)
    .count(limit)
}
/// Ensure the consumer group exists, if not, create it.
pub async fn ensure_indexer_consumer_group(
redis_client: &mut ConnectionManager,
) -> Result<(), IndexerError> {
let result: RedisResult<()> = redis_client
.xgroup_create_mkstream(INDEX_TASK_STREAM_NAME, INDEXER_WORKER_GROUP_NAME, "0")
.await;
if let Err(redis_error) = result {
if let Some(code) = redis_error.code() {
if code == "BUSYGROUP" {
return Ok(());
}
if code == "NOGROUP" {
return Err(IndexerError::StreamGroupNotExist(
INDEXER_WORKER_GROUP_NAME.to_string(),
));
}
}
error!("Error when creating consumer group: {:?}", redis_error);
return Err(IndexerError::Internal(redis_error.into()));
}
Ok(())
}

View file

@ -0,0 +1,552 @@
use crate::collab_indexer::{Indexer, IndexerProvider};
use crate::entity::EmbeddingRecord;
use crate::error::IndexerError;
use crate::metrics::EmbeddingMetrics;
use crate::queue::add_background_embed_task;
use crate::thread_pool::{ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};
use crate::vector::embedder::Embedder;
use crate::vector::open_ai;
use app_error::AppError;
use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse};
use collab::lock::RwLock;
use collab::preclude::Collab;
use collab_document::document::DocumentBody;
use collab_entity::CollabType;
use database::collab::CollabStorage;
use database::index::{update_collab_indexed_at, upsert_collab_embeddings};
use database::workspace::select_workspace_settings;
use database_entity::dto::AFCollabEmbeddedChunk;
use infra::env_util::get_env_var;
use rayon::prelude::*;
use redis::aio::ConnectionManager;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use std::cmp::max;
use std::collections::HashSet;
use std::ops::DerefMut;
use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::RwLock as TokioRwLock;
use tokio::time::timeout;
use tracing::{debug, error, info, instrument, trace, warn};
use uuid::Uuid;
/// Coordinates embedding generation for collabs: accepts pending tasks,
/// runs them on a dedicated thread pool, and forwards finished embeddings
/// to a Postgres writer task, with a Redis stream as overflow fallback.
pub struct IndexerScheduler {
  pub(crate) indexer_provider: Arc<IndexerProvider>,
  pub(crate) pg_pool: PgPool,
  pub(crate) storage: Arc<dyn CollabStorage>,
  // Thread pool used to run CPU-bound embedding work off the async runtime.
  pub(crate) threads: Arc<ThreadPoolNoAbort>,
  #[allow(dead_code)]
  pub(crate) metrics: Arc<EmbeddingMetrics>,
  // Unbounded channel feeding the Postgres-write background task.
  write_embedding_tx: UnboundedSender<EmbeddingRecord>,
  // Bounded channel feeding embedding generation; on overflow the task is
  // rerouted to the Redis-backed background queue (see `embed_immediately`).
  gen_embedding_tx: mpsc::Sender<UnindexedCollabTask>,
  config: IndexerConfiguration,
  redis_client: ConnectionManager,
}
/// Runtime configuration for the indexer scheduler.
pub struct IndexerConfiguration {
  /// Master switch: when false, no indexing work is scheduled.
  pub enable: bool,
  /// OpenAI API key used to generate embeddings; an empty key disables indexing.
  pub openai_api_key: String,
  /// High watermark for the number of embeddings that can be buffered before being written to the database.
  pub embedding_buffer_size: usize,
}

// Debug is hand-written (instead of derived) so the OpenAI API key is never
// leaked into logs: the configuration is logged with `{:?}` at startup.
impl std::fmt::Debug for IndexerConfiguration {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("IndexerConfiguration")
      .field("enable", &self.enable)
      .field("openai_api_key", &"***")
      .field("embedding_buffer_size", &self.embedding_buffer_size)
      .finish()
  }
}
impl IndexerScheduler {
  /// Creates the scheduler and, when indexing is enabled, spawns the two
  /// background loops: one generating embeddings on the thread pool and one
  /// writing finished embeddings to Postgres.
  pub fn new(
    indexer_provider: Arc<IndexerProvider>,
    pg_pool: PgPool,
    storage: Arc<dyn CollabStorage>,
    metrics: Arc<EmbeddingMetrics>,
    config: IndexerConfiguration,
    redis_client: ConnectionManager,
  ) -> Arc<Self> {
    // Since threads often block while waiting for I/O, you can use more threads than CPU cores to improve concurrency.
    // A good rule of thumb is 2x to 10x the number of CPU cores
    // (floor of 5 threads even if the env var asks for fewer).
    let num_thread = max(
      get_env_var("APPFLOWY_INDEXER_SCHEDULER_NUM_THREAD", "50")
        .parse::<usize>()
        .unwrap_or(50),
      5,
    );

    info!("Indexer scheduler config: {:?}", config);
    // Unbounded: finished embeddings must not be dropped before the DB write.
    let (write_embedding_tx, write_embedding_rx) = unbounded_channel::<EmbeddingRecord>();
    // Bounded by the configured buffer size; senders fall back to the Redis
    // queue when this channel is full.
    let (gen_embedding_tx, gen_embedding_rx) =
      mpsc::channel::<UnindexedCollabTask>(config.embedding_buffer_size);
    let threads = Arc::new(
      ThreadPoolNoAbortBuilder::new()
        .num_threads(num_thread)
        .thread_name(|index| format!("create-embedding-thread-{index}"))
        .build()
        .unwrap(),
    );

    let this = Arc::new(Self {
      indexer_provider,
      pg_pool,
      storage,
      threads,
      metrics,
      write_embedding_tx,
      gen_embedding_tx,
      config,
      redis_client,
    });

    info!(
      "Indexer scheduler is enabled: {}, num threads: {}",
      this.index_enabled(),
      num_thread
    );

    // Shared slot holding the latest Postgres-write failure; presumably read
    // by the generator loop to back off — confirm in spawn_rayon_generate_embeddings.
    let latest_write_embedding_err = Arc::new(TokioRwLock::new(None));
    if this.index_enabled() {
      // Downgrade to Weak so the background loop does not keep the scheduler alive.
      tokio::spawn(spawn_rayon_generate_embeddings(
        gen_embedding_rx,
        Arc::downgrade(&this),
        num_thread,
        latest_write_embedding_err.clone(),
      ));
      tokio::spawn(spawn_pg_write_embeddings(
        write_embedding_rx,
        this.pg_pool.clone(),
        this.metrics.clone(),
        latest_write_embedding_err,
      ));
    }
    this
  }
fn index_enabled(&self) -> bool {
// if indexing is disabled, return false
if !self.config.enable {
return false;
}
// if openai api key is empty, return false
if self.config.openai_api_key.is_empty() {
return false;
}
true
}
pub fn is_indexing_enabled(&self, collab_type: &CollabType) -> bool {
self.indexer_provider.is_indexing_enabled(collab_type)
}
pub(crate) fn create_embedder(&self) -> Result<Embedder, AppError> {
if self.config.openai_api_key.is_empty() {
return Err(AppError::AIServiceUnavailable(
"OpenAI API key is empty".to_string(),
));
}
Ok(Embedder::OpenAI(open_ai::Embedder::new(
self.config.openai_api_key.clone(),
)))
}
pub fn create_search_embeddings(
&self,
request: EmbeddingRequest,
) -> Result<OpenAIEmbeddingResponse, AppError> {
let embedder = self.create_embedder()?;
let embeddings = embedder.embed(request)?;
Ok(embeddings)
}
pub fn embed_in_background(
&self,
pending_collabs: Vec<UnindexedCollabTask>,
) -> Result<(), AppError> {
if !self.index_enabled() {
return Ok(());
}
let redis_client = self.redis_client.clone();
tokio::spawn(add_background_embed_task(redis_client, pending_collabs));
Ok(())
}
pub fn embed_immediately(&self, pending_collab: UnindexedCollabTask) -> Result<(), AppError> {
if !self.index_enabled() {
return Ok(());
}
if let Err(err) = self.gen_embedding_tx.try_send(pending_collab) {
match err {
TrySendError::Full(pending) => {
warn!("[Embedding] Embedding queue is full, embedding in background");
self.embed_in_background(vec![pending])?;
self.metrics.record_failed_embed_count(1);
},
TrySendError::Closed(_) => {
error!("Failed to send embedding record: channel closed");
},
}
}
Ok(())
}
pub fn index_pending_collab_one(
&self,
pending_collab: UnindexedCollabTask,
background: bool,
) -> Result<(), AppError> {
if !self.index_enabled() {
return Ok(());
}
let indexer = self
.indexer_provider
.indexer_for(&pending_collab.collab_type);
if indexer.is_none() {
return Ok(());
}
if background {
let _ = self.embed_in_background(vec![pending_collab]);
} else {
let _ = self.embed_immediately(pending_collab);
}
Ok(())
}
/// Index all pending collabs in the background
pub fn index_pending_collabs(
&self,
mut pending_collabs: Vec<UnindexedCollabTask>,
) -> Result<(), AppError> {
if !self.index_enabled() {
return Ok(());
}
pending_collabs.retain(|collab| self.is_indexing_enabled(&collab.collab_type));
if pending_collabs.is_empty() {
return Ok(());
}
info!("indexing {} collabs in background", pending_collabs.len());
let _ = self.embed_in_background(pending_collabs);
Ok(())
}
pub async fn index_collab_immediately(
&self,
workspace_id: &str,
object_id: &str,
collab: &Arc<RwLock<Collab>>,
collab_type: &CollabType,
) -> Result<(), AppError> {
if !self.index_enabled() {
return Ok(());
}
if !self.is_indexing_enabled(collab_type) {
return Ok(());
}
match collab_type {
CollabType::Document => {
let lock = collab.read().await;
let txn = lock.transact();
let text = DocumentBody::from_collab(&lock)
.and_then(|body| body.to_plain_text(txn, false, true).ok());
drop(lock); // release the read lock ASAP
if let Some(text) = text {
if !text.is_empty() {
let pending = UnindexedCollabTask::new(
Uuid::parse_str(workspace_id)?,
object_id.to_string(),
collab_type.clone(),
UnindexedData::UnindexedText(text),
);
self.embed_immediately(pending)?;
}
}
},
_ => {
// TODO(nathan): support other collab types
},
}
Ok(())
}
pub async fn can_index_workspace(&self, workspace_id: &str) -> Result<bool, AppError> {
if !self.index_enabled() {
return Ok(false);
}
let uuid = Uuid::parse_str(workspace_id)?;
let settings = select_workspace_settings(&self.pg_pool, &uuid).await?;
match settings {
None => Ok(true),
Some(settings) => Ok(!settings.disable_search_indexing),
}
}
}
/// Long-running worker: drains `UnindexedCollabTask`s from the channel in
/// batches of up to `buffer_size`, generates embeddings on the rayon pool,
/// and forwards finished records to the Postgres writer channel.
///
/// Holds only a `Weak` scheduler reference so the loop terminates once the
/// scheduler is dropped. When the writer task last failed with a timeout,
/// this worker backs off for 30 seconds before pulling the next batch.
async fn spawn_rayon_generate_embeddings(
  mut rx: mpsc::Receiver<UnindexedCollabTask>,
  scheduler: Weak<IndexerScheduler>,
  buffer_size: usize,
  latest_write_embedding_err: Arc<TokioRwLock<Option<AppError>>>,
) {
  let mut buf = Vec::with_capacity(buffer_size);
  loop {
    // Take (and clear) the writer's last error; throttle only on timeouts.
    let latest_error = latest_write_embedding_err.write().await.take();
    if let Some(err) = latest_error {
      if matches!(err, AppError::ActionTimeout(_)) {
        info!(
          "[Embedding] last write embedding task failed with timeout, waiting for 30s before retrying..."
        );
        tokio::time::sleep(Duration::from_secs(30)).await;
      }
    }
    let n = rx.recv_many(&mut buf, buffer_size).await;
    let scheduler = match scheduler.upgrade() {
      Some(scheduler) => scheduler,
      None => {
        // Scheduler dropped — normal shutdown path for this worker.
        error!("[Embedding] Failed to upgrade scheduler");
        break;
      },
    };
    // recv_many returns 0 only when the channel is closed and empty.
    if n == 0 {
      info!("[Embedding] Stop generating embeddings");
      break;
    }
    let start = Instant::now();
    let records = buf.drain(..n).collect::<Vec<_>>();
    trace!(
      "[Embedding] received {} embeddings to generate",
      records.len()
    );
    let metrics = scheduler.metrics.clone();
    let threads = scheduler.threads.clone();
    let indexer_provider = scheduler.indexer_provider.clone();
    let write_embedding_tx = scheduler.write_embedding_tx.clone();
    let embedder = scheduler.create_embedder();
    // CPU-bound embedding work runs via spawn_blocking + rayon so the tokio
    // runtime threads are never blocked.
    let result = tokio::task::spawn_blocking(move || {
      match embedder {
        Ok(embedder) => {
          records.into_par_iter().for_each(|record| {
            // `install` routes the work onto the panic-tolerant pool; a
            // panicking task surfaces as Err instead of aborting the process.
            let result = threads.install(|| {
              let indexer = indexer_provider.indexer_for(&record.collab_type);
              match process_collab(&embedder, indexer, &record.object_id, record.data, &metrics) {
                Ok(Some((tokens_used, contents))) => {
                  if let Err(err) = write_embedding_tx.send(EmbeddingRecord {
                    workspace_id: record.workspace_id,
                    object_id: record.object_id,
                    collab_type: record.collab_type,
                    tokens_used,
                    contents,
                  }) {
                    error!("Failed to send embedding record: {}", err);
                  }
                },
                Ok(None) => {
                  debug!("No embedding for collab:{}", record.object_id);
                },
                Err(err) => {
                  warn!(
                    "Failed to create embeddings content for collab:{}, error:{}",
                    record.object_id, err
                  );
                },
              }
            });
            if let Err(err) = result {
              error!("Failed to install a task to rayon thread pool: {}", err);
            }
          });
        },
        Err(err) => error!("[Embedding] Failed to create embedder: {}", err),
      }
      Ok::<_, IndexerError>(())
    })
    .await;
    match result {
      Ok(Ok(_)) => {
        scheduler
          .metrics
          .record_gen_embedding_time(n as u32, start.elapsed().as_millis());
        trace!("Successfully generated embeddings");
      },
      Ok(Err(err)) => error!("Failed to generate embeddings: {}", err),
      Err(err) => error!("Failed to spawn a task to generate embeddings: {}", err),
    }
  }
}
/// Max number of finished embedding records drained per database write batch.
const EMBEDDING_RECORD_BUFFER_SIZE: usize = 10;
/// Long-running worker: drains finished `EmbeddingRecord`s and persists them
/// to Postgres in small batches (with a 20s timeout per batch). Failures are
/// published into `latest_write_embedding_error` so the generator task can
/// back off before producing more work.
pub async fn spawn_pg_write_embeddings(
  mut rx: UnboundedReceiver<EmbeddingRecord>,
  pg_pool: PgPool,
  metrics: Arc<EmbeddingMetrics>,
  latest_write_embedding_error: Arc<TokioRwLock<Option<AppError>>>,
) {
  let mut buf = Vec::with_capacity(EMBEDDING_RECORD_BUFFER_SIZE);
  loop {
    let n = rx.recv_many(&mut buf, EMBEDDING_RECORD_BUFFER_SIZE).await;
    // recv_many returns 0 only when the channel is closed and empty.
    if n == 0 {
      info!("Stop writing embeddings");
      break;
    }
    trace!("[Embedding] received {} embeddings to write", n);
    let start = Instant::now();
    let records = buf.drain(..n).collect::<Vec<_>>();
    for record in records.iter() {
      info!(
        "[Embedding] generate collab:{} embeddings, tokens used: {}",
        record.object_id, record.tokens_used
      );
    }
    // Convert a timed-out write into an ActionTimeout error so the generator
    // recognizes it and throttles.
    let result = timeout(
      Duration::from_secs(20),
      batch_insert_records(&pg_pool, records),
    )
    .await
    .unwrap_or_else(|_| {
      Err(AppError::ActionTimeout(
        "timeout when writing embeddings".to_string(),
      ))
    });
    match result {
      Ok(_) => {
        trace!("[Embedding] save {} embeddings to disk", n);
        metrics.record_write_embedding_time(start.elapsed().as_millis());
      },
      Err(err) => {
        error!("Failed to write collab embedding to disk:{}", err);
        latest_write_embedding_error.write().await.replace(err);
      },
    }
  }
}
/// Persists a batch of embedding records inside one transaction: stamps
/// `af_collab.indexed_at` and upserts the embedding chunks for each record.
#[instrument(level = "trace", skip_all)]
pub(crate) async fn batch_insert_records(
  pg_pool: &PgPool,
  records: Vec<EmbeddingRecord>,
) -> Result<(), AppError> {
  // Deduplicate by object_id so the same collab is written at most once
  // per batch.
  // NOTE(review): this keeps the FIRST record per object_id; if later
  // records in the batch are fresher, the stale one wins — confirm intended.
  let mut seen = HashSet::new();
  let records = records
    .into_iter()
    .filter(|record| seen.insert(record.object_id.clone()))
    .collect::<Vec<_>>();
  let mut txn = pg_pool.begin().await?;
  for record in records {
    // Mark the collab as indexed even when it produced zero chunks, so it
    // is not re-selected by the "indexed_at IS NULL" background query.
    update_collab_indexed_at(
      txn.deref_mut(),
      &record.object_id,
      &record.collab_type,
      chrono::Utc::now(),
    )
    .await?;
    upsert_collab_embeddings(
      &mut txn,
      &record.workspace_id,
      &record.object_id,
      record.collab_type,
      record.tokens_used,
      record.contents,
    )
    .await?;
  }
  txn.commit().await.map_err(|e| {
    error!("[Embedding] Failed to commit transaction: {:?}", e);
    e
  })?;
  Ok(())
}
/// Generates embedding chunks for a single collab payload and embeds them.
///
/// This function must be called within the rayon thread pool. Returns
/// `Ok(None)` when no indexer exists for the collab type or the indexer
/// produced no embeddings; otherwise returns the token count consumed
/// together with the embedded chunks.
fn process_collab(
  embedder: &Embedder,
  indexer: Option<Arc<dyn Indexer>>,
  object_id: &str,
  data: UnindexedData,
  metrics: &EmbeddingMetrics,
) -> Result<Option<(u32, Vec<AFCollabEmbeddedChunk>)>, AppError> {
  // Without a registered indexer there is nothing to do.
  let Some(indexer) = indexer else {
    return Ok(None);
  };
  metrics.record_embed_count(1);
  // Single-variant enum: irrefutable destructuring.
  let UnindexedData::UnindexedText(text) = data;
  let chunks =
    indexer.create_embedded_chunks_from_text(object_id.to_string(), text, embedder.model())?;
  match indexer.embed(embedder, chunks) {
    Ok(Some(embeddings)) => Ok(Some((embeddings.tokens_consumed, embeddings.params))),
    Ok(None) => Ok(None),
    Err(err) => {
      // Surface embedding failures in the metrics before propagating.
      metrics.record_failed_embed_count(1);
      Err(err)
    },
  }
}
/// A unit of indexing work: one collab's extracted content awaiting
/// embedding. Serializable so it can also be queued in Redis for background
/// processing.
#[derive(Debug, Serialize, Deserialize)]
pub struct UnindexedCollabTask {
  pub workspace_id: Uuid,
  pub object_id: String,
  pub collab_type: CollabType,
  /// Raw content extracted from the collab (currently plain text only).
  pub data: UnindexedData,
  /// Unix timestamp (seconds) at which the task was created.
  pub created_at: i64,
}
impl UnindexedCollabTask {
  /// Builds an indexing task for the given collab, stamping it with the
  /// current Unix time.
  pub fn new(
    workspace_id: Uuid,
    object_id: String,
    collab_type: CollabType,
    data: UnindexedData,
  ) -> Self {
    let created_at = chrono::Utc::now().timestamp();
    Self {
      workspace_id,
      object_id,
      collab_type,
      data,
      created_at,
    }
  }
}
/// Content payload of an indexing task; serializable for the Redis queue.
#[derive(Debug, Serialize, Deserialize)]
pub enum UnindexedData {
  /// Plain text extracted from a document collab.
  UnindexedText(String),
}
impl UnindexedData {
  /// Returns true when the payload holds no content worth embedding.
  pub fn is_empty(&self) -> bool {
    // Single-variant enum: irrefutable destructuring.
    let UnindexedData::UnindexedText(text) = self;
    text.is_empty()
  }
}

View file

@ -0,0 +1,167 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use rayon::{ThreadPool, ThreadPoolBuilder};
use thiserror::Error;
/// A thread pool that does not abort on panics.
///
/// This custom thread pool wraps Rayon's `ThreadPool` and ensures that the
/// thread pool can recover from panics gracefully. It detects any panics in
/// worker threads and prevents the entire application from aborting.
#[derive(Debug)]
pub struct ThreadPoolNoAbort {
  /// Internal Rayon thread pool.
  thread_pool: ThreadPool,
  /// Atomic flag set by the pool's panic handler when any worker panics;
  /// consumed (reset) by `install`.
  catched_panic: Arc<AtomicBool>,
}
impl ThreadPoolNoAbort {
  /// Runs `op` on the wrapped Rayon pool, reporting panics as an error
  /// instead of aborting the process.
  ///
  /// The pool's panic handler sets an atomic flag whenever a worker thread
  /// panics; after the closure finishes this method consumes that flag and
  /// converts it into `Err(CatchedPanic)`.
  ///
  /// # Returns
  /// * `Ok(R)` - The closure's result when execution completed normally.
  /// * `Err(CatchedPanic)` - A panic occurred somewhere in the pool during
  ///   execution.
  pub fn install<OP, R>(&self, op: OP) -> Result<R, CatchedPanic>
  where
    OP: FnOnce() -> R + Send,
    R: Send,
  {
    let output = self.thread_pool.install(op);
    // Swap the flag back to false so the next call starts clean.
    match self.catched_panic.swap(false, Ordering::SeqCst) {
      true => Err(CatchedPanic),
      false => Ok(output),
    }
  }

  /// Number of worker threads currently owned by the pool.
  pub fn current_num_threads(&self) -> usize {
    self.thread_pool.current_num_threads()
  }
}
/// Error indicating that a panic occurred during thread pool execution.
///
/// This error is returned when a closure executed in the thread pool panics.
#[derive(Error, Debug)]
#[error("A panic occurred happened in the thread pool. Check the logs for more information")]
pub struct CatchedPanic;
/// A builder for creating a `ThreadPoolNoAbort` instance.
///
/// This builder wraps Rayon's `ThreadPoolBuilder` and customizes the panic
/// handling behavior so panics are recorded instead of aborting the process.
#[derive(Default)]
pub struct ThreadPoolNoAbortBuilder(ThreadPoolBuilder);
impl ThreadPoolNoAbortBuilder {
  /// Creates a builder with Rayon's default settings.
  pub fn new() -> ThreadPoolNoAbortBuilder {
    Self::default()
  }

  /// Installs a naming function invoked for every spawned worker thread.
  ///
  /// # Arguments
  /// * `closure` - Maps a thread index to that thread's name.
  pub fn thread_name<F>(mut self, closure: F) -> Self
  where
    F: FnMut(usize) -> String + 'static,
  {
    self.0 = self.0.thread_name(closure);
    self
  }

  /// Fixes the number of worker threads the pool will spawn.
  ///
  /// # Arguments
  /// * `num_threads` - Thread count for the pool.
  pub fn num_threads(mut self, num_threads: usize) -> ThreadPoolNoAbortBuilder {
    self.0 = self.0.num_threads(num_threads);
    self
  }

  /// Finalizes the builder, wiring up a panic handler that records panics in
  /// a shared atomic flag instead of letting Rayon abort the process.
  ///
  /// # Returns
  /// * `Ok(ThreadPoolNoAbort)` - The constructed pool.
  /// * `Err(ThreadPoolBuildError)` - The underlying pool failed to build.
  pub fn build(mut self) -> Result<ThreadPoolNoAbort, rayon::ThreadPoolBuildError> {
    let catched_panic = Arc::new(AtomicBool::new(false));
    let flag = catched_panic.clone();
    // Record the panic instead of aborting; `install` reads and resets it.
    self.0 = self
      .0
      .panic_handler(move |_panic| flag.store(true, Ordering::SeqCst));
    let thread_pool = self.0.build()?;
    Ok(ThreadPoolNoAbort {
      thread_pool,
      catched_panic,
    })
  }
}
#[cfg(test)]
mod tests {
  use super::*;
  use std::sync::atomic::{AtomicUsize, Ordering};

  #[test]
  fn test_install_closure_success() {
    // A successful closure must surface its return value unchanged.
    let pool = ThreadPoolNoAbortBuilder::new()
      .num_threads(4)
      .build()
      .expect("Failed to build thread pool");
    let result = pool.install(|| 42);
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), 42);
  }

  #[test]
  fn test_multiple_threads_execution() {
    // 100 tasks each bump a shared counter; every install must succeed and
    // every increment must be observed.
    let pool = ThreadPoolNoAbortBuilder::new()
      .num_threads(8)
      .build()
      .expect("Failed to build thread pool");
    let counter = Arc::new(AtomicUsize::new(0));
    let results: Vec<_> = (0..100)
      .map(|_| {
        let counter = counter.clone();
        pool.install(move || {
          counter.fetch_add(1, Ordering::SeqCst);
        })
      })
      .collect();
    assert!(results.iter().all(|r| r.is_ok()));
    assert_eq!(counter.load(Ordering::SeqCst), 100);
  }
}

View file

@ -0,0 +1,223 @@
use crate::collab_indexer::IndexerProvider;
use crate::entity::{EmbeddingRecord, UnindexedCollab};
use crate::scheduler::{batch_insert_records, IndexerScheduler};
use crate::thread_pool::ThreadPoolNoAbort;
use crate::vector::embedder::Embedder;
use collab::core::collab::DataSource;
use collab::core::origin::CollabOrigin;
use collab::preclude::Collab;
use collab_entity::CollabType;
use database::collab::{CollabStorage, GetCollabOrigin};
use database::index::stream_collabs_without_embeddings;
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use rayon::iter::ParallelIterator;
use rayon::prelude::IntoParallelIterator;
use sqlx::pool::PoolConnection;
use sqlx::Postgres;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{error, info, trace};
use uuid::Uuid;
/// Continuously indexes unembedded collabs of a workspace in the background.
///
/// Streams collabs without embeddings from the database, batches them, and
/// generates + persists embeddings per batch. Stops when the scheduler's
/// thread pool is dropped or the database stays unreachable for ~1 minute.
#[allow(dead_code)]
pub(crate) async fn index_workspace(scheduler: Arc<IndexerScheduler>, workspace_id: Uuid) {
  let weak_threads = Arc::downgrade(&scheduler.threads);
  let mut retry_delay = Duration::from_secs(2);
  loop {
    let threads = match weak_threads.upgrade() {
      Some(threads) => threads,
      None => {
        info!("[Embedding] thread pool is dropped, stop indexing");
        break;
      },
    };

    let conn = scheduler.pg_pool.try_acquire();
    if conn.is_none() {
      tokio::time::sleep(retry_delay).await;
      // Exponential backoff: 2s, 4s, 8s, 16s, 32s, then give up past 60s.
      retry_delay = retry_delay.saturating_mul(2);
      if retry_delay > Duration::from_secs(60) {
        error!("[Embedding] failed to acquire db connection for 1 minute, stop indexing");
        break;
      }
      continue;
    }

    retry_delay = Duration::from_secs(2);
    let mut conn = conn.unwrap();
    let mut stream =
      stream_unindexed_collabs(&mut conn, workspace_id, scheduler.storage.clone(), 50).await;

    let batch_size = 5;
    let mut unindexed_collabs = Vec::with_capacity(batch_size);
    while let Some(Ok(collab)) = stream.next().await {
      // Push first, then flush once the batch is full. The previous logic
      // flushed *instead of* pushing when the buffer was already full, which
      // silently dropped every (batch_size + 1)-th collab from the stream.
      unindexed_collabs.push(collab);
      if unindexed_collabs.len() >= batch_size {
        index_then_write_embedding_to_disk(
          &scheduler,
          threads.clone(),
          std::mem::take(&mut unindexed_collabs),
        )
        .await;
      }
    }

    // Flush the trailing partial batch.
    if !unindexed_collabs.is_empty() {
      index_then_write_embedding_to_disk(&scheduler, threads.clone(), unindexed_collabs).await;
    }
  }
}
async fn index_then_write_embedding_to_disk(
scheduler: &Arc<IndexerScheduler>,
threads: Arc<ThreadPoolNoAbort>,
unindexed_collabs: Vec<UnindexedCollab>,
) {
info!(
"[Embedding] process batch {:?} embeddings",
unindexed_collabs
.iter()
.map(|v| v.object_id.clone())
.collect::<Vec<_>>()
);
if let Ok(embedder) = scheduler.create_embedder() {
let start = Instant::now();
let embeddings = create_embeddings(
embedder,
&scheduler.indexer_provider,
threads.clone(),
unindexed_collabs,
)
.await;
scheduler
.metrics
.record_gen_embedding_time(embeddings.len() as u32, start.elapsed().as_millis());
let write_start = Instant::now();
let n = embeddings.len();
match batch_insert_records(&scheduler.pg_pool, embeddings).await {
Ok(_) => trace!(
"[Embedding] upsert {} embeddings success, cost:{}ms",
n,
write_start.elapsed().as_millis()
),
Err(err) => error!("{}", err),
}
scheduler
.metrics
.record_write_embedding_time(write_start.elapsed().as_millis());
tokio::time::sleep(Duration::from_secs(5)).await;
} else {
trace!("[Embedding] no embeddings to process in this batch");
}
}
/// Streams unindexed collabs for the workspace, resolving each database row
/// into its encoded collab payload.
///
/// Only `CollabType::Document` rows are materialized; other collab types are
/// filtered out of the stream. Errors from the underlying query or from
/// fetching the encoded collab are propagated as stream items.
async fn stream_unindexed_collabs(
  conn: &mut PoolConnection<Postgres>,
  workspace_id: Uuid,
  storage: Arc<dyn CollabStorage>,
  limit: i64,
) -> BoxStream<Result<UnindexedCollab, anyhow::Error>> {
  let cloned_storage = storage.clone();
  stream_collabs_without_embeddings(conn, workspace_id, limit)
    .await
    .map(move |result| {
      let storage = cloned_storage.clone();
      // Each row becomes a future that loads the collab's encoded payload.
      async move {
        match result {
          Ok(cid) => match cid.collab_type {
            CollabType::Document => {
              let collab = storage
                .get_encode_collab(GetCollabOrigin::Server, cid.clone().into(), false)
                .await?;

              Ok(Some(UnindexedCollab {
                workspace_id: cid.workspace_id,
                object_id: cid.object_id,
                collab_type: cid.collab_type,
                collab,
              }))
            },
            // TODO(nathan): support other collab types
            _ => Ok::<_, anyhow::Error>(None),
          },
          Err(e) => Err(e.into()),
        }
      }
    })
    // Drop the unsupported-type `None`s; keep successes and errors.
    .filter_map(|future| async {
      match future.await {
        Ok(Some(unindexed_collab)) => Some(Ok(unindexed_collab)),
        Ok(None) => None,
        Err(e) => Some(Err(e)),
      }
    })
    .boxed()
}
/// Generates embedding records for a batch of unindexed collabs in parallel.
///
/// For each collab: decodes its doc state, chunks the content, and embeds
/// the chunks on the panic-tolerant thread pool. Collabs that fail to decode
/// or have no indexer are skipped; collabs that produce zero chunks yield an
/// empty record so they are still stamped as indexed downstream.
async fn create_embeddings(
  embedder: Embedder,
  indexer_provider: &Arc<IndexerProvider>,
  threads: Arc<ThreadPoolNoAbort>,
  unindexed_records: Vec<UnindexedCollab>,
) -> Vec<EmbeddingRecord> {
  unindexed_records
    .into_par_iter()
    .flat_map(|unindexed| {
      // Skip collab types without a registered indexer.
      let indexer = indexer_provider.indexer_for(&unindexed.collab_type)?;
      let collab = Collab::new_with_source(
        CollabOrigin::Empty,
        &unindexed.object_id,
        DataSource::DocStateV1(unindexed.collab.doc_state.into()),
        vec![],
        false,
      )
      .ok()?;

      let chunks = indexer
        .create_embedded_chunks_from_collab(&collab, embedder.model())
        .ok()?;
      if chunks.is_empty() {
        // An empty record still marks the collab as indexed downstream.
        trace!("[Embedding] {} has no embeddings", unindexed.object_id,);
        return Some(EmbeddingRecord::empty(
          unindexed.workspace_id,
          unindexed.object_id,
          unindexed.collab_type,
        ));
      }

      // Run the blocking embed call inside the panic-tolerant pool.
      let result = threads.install(|| match indexer.embed(&embedder, chunks) {
        Ok(embeddings) => embeddings.map(|embeddings| EmbeddingRecord {
          workspace_id: unindexed.workspace_id,
          object_id: unindexed.object_id,
          collab_type: unindexed.collab_type,
          tokens_used: embeddings.tokens_consumed,
          contents: embeddings.params,
        }),
        Err(err) => {
          error!("Failed to embed collab: {}", err);
          None
        },
      });

      if let Ok(Some(record)) = &result {
        trace!(
          "[Embedding] generate collab:{} embeddings, tokens used: {}",
          record.object_id,
          record.tokens_used
        );
      }

      // Err here means a worker panicked inside the pool.
      result.unwrap_or_else(|err| {
        error!("Failed to spawn a task to index collab: {}", err);
        None
      })
    })
    .collect::<Vec<_>>()
}

View file

@ -1,4 +1,4 @@
use crate::indexer::vector::open_ai;
use crate::vector::open_ai;
use app_error::AppError;
use appflowy_ai_client::dto::{EmbeddingModel, EmbeddingRequest, OpenAIEmbeddingResponse};

View file

@ -1,7 +1,74 @@
use crate::vector::rest::check_response;
use anyhow::anyhow;
use app_error::AppError;
use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse};
use serde::de::DeserializeOwned;
use std::time::Duration;
use tiktoken_rs::CoreBPE;
use unicode_segmentation::UnicodeSegmentation;
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
pub const REQUEST_PARALLELISM: usize = 40;
#[derive(Debug, Clone)]
pub struct Embedder {
bearer: String,
client: ureq::Agent,
}
impl Embedder {
pub fn new(api_key: String) -> Self {
let bearer = format!("Bearer {api_key}");
let client = ureq::AgentBuilder::new()
.max_idle_connections(REQUEST_PARALLELISM * 2)
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
.build();
Self { bearer, client }
}
pub fn embed(&self, params: EmbeddingRequest) -> Result<OpenAIEmbeddingResponse, AppError> {
for attempt in 0..3 {
let request = self
.client
.post(OPENAI_EMBEDDINGS_URL)
.set("Authorization", &self.bearer)
.set("Content-Type", "application/json");
let result = check_response(request.send_json(&params));
let retry_duration = match result {
Ok(response) => {
let data = from_response::<OpenAIEmbeddingResponse>(response)?;
return Ok(data);
},
Err(retry) => retry.into_duration(attempt),
}
.map_err(|err| AppError::Internal(err.into()))?;
let retry_duration = retry_duration.min(Duration::from_secs(10));
std::thread::sleep(retry_duration);
}
Err(AppError::Internal(anyhow!(
"Failed to generate embeddings after 3 attempts"
)))
}
}
pub fn from_response<T>(resp: ureq::Response) -> Result<T, anyhow::Error>
where
T: DeserializeOwned,
{
let status_code = resp.status();
if status_code != 200 {
let body = resp.into_string()?;
anyhow::bail!("error code: {}, {}", status_code, body)
}
let resp = resp.into_json()?;
Ok(resp)
}
/// ## Execution Time Comparison Results
///
/// The following results were observed when running `execution_time_comparison_tests`:
@ -128,7 +195,7 @@ pub fn split_text_by_max_content_len(
#[cfg(test)]
mod tests {
use crate::indexer::open_ai::{split_text_by_max_content_len, split_text_by_max_tokens};
use crate::vector::open_ai::{split_text_by_max_content_len, split_text_by_max_tokens};
use tiktoken_rs::cl100k_base;
#[test]

View file

@ -1,4 +1,4 @@
use crate::thread_pool_no_abort::CatchedPanic;
use crate::thread_pool::CatchedPanic;
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]

View file

@ -0,0 +1,3 @@
-- Add migration script here
ALTER TABLE af_collab
ADD COLUMN indexed_at TIMESTAMP WITH TIME ZONE DEFAULT NULL;

View file

@ -95,6 +95,7 @@ aws-sdk-s3 = { version = "1.36.0", features = [
"rt-tokio",
] }
zstd.workspace = true
indexer.workspace = true
[dev-dependencies]
rand = "0.8.5"

View file

@ -27,13 +27,14 @@ use crate::collab::cache::CollabCache;
use crate::collab::storage::CollabStorageImpl;
use crate::command::{CLCommandReceiver, CLCommandSender};
use crate::config::{get_env_var, Config, DatabaseSetting, S3Setting};
use crate::indexer::{IndexerConfiguration, IndexerProvider, IndexerScheduler};
use crate::pg_listener::PgListeners;
use crate::snapshot::SnapshotControl;
use crate::state::{AppMetrics, AppState, UserCache};
use crate::CollaborationServer;
use access_control::casbin::access::AccessControl;
use database::file::s3_client_impl::AwsS3BucketClientImpl;
use indexer::collab_indexer::IndexerProvider;
use indexer::scheduler::{IndexerConfiguration, IndexerScheduler};
pub struct Application {
actix_server: Server,
@ -150,10 +151,13 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result<A
info!("Setting up Indexer provider...");
let embedder_config = IndexerConfiguration {
enable: crate::config::get_env_var("APPFLOWY_INDEXER_ENABLED", "true")
enable: get_env_var("APPFLOWY_INDEXER_ENABLED", "true")
.parse::<bool>()
.unwrap_or(true),
openai_api_key: get_env_var("APPFLOWY_AI_OPENAI_API_KEY", ""),
embedding_buffer_size: get_env_var("APPFLOWY_INDEXER_EMBEDDING_BUFFER_SIZE", "2000")
.parse::<usize>()
.unwrap_or(2000),
};
let indexer_scheduler = IndexerScheduler::new(
IndexerProvider::new(),
@ -161,6 +165,7 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result<A
collab_storage.clone(),
metrics.embedding_metrics.clone(),
embedder_config,
redis_conn_manager.clone(),
);
let app_state = AppState {

View file

@ -1,6 +1,6 @@
use app_error::AppError;
use async_trait::async_trait;
use collab_rt_protocol::spawn_blocking_validate_encode_collab;
use collab_rt_protocol::validate_encode_collab;
use database_entity::dto::CollabParams;
#[async_trait]
@ -11,12 +11,8 @@ pub trait CollabValidator {
#[async_trait]
impl CollabValidator for CollabParams {
async fn check_encode_collab(&self) -> Result<(), AppError> {
spawn_blocking_validate_encode_collab(
&self.object_id,
&self.encoded_collab_v1,
&self.collab_type,
)
.await
.map_err(|err| AppError::NoRequiredData(err.to_string()))
validate_encode_collab(&self.object_id, &self.encoded_collab_v1, &self.collab_type)
.await
.map_err(|err| AppError::NoRequiredData(err.to_string()))
}
}

View file

@ -19,13 +19,12 @@ use yrs::{ReadTxn, StateVector};
use collab_stream::error::StreamError;
use database::collab::CollabStorage;
use crate::error::RealtimeError;
use crate::group::broadcast::{CollabBroadcast, Subscription};
use crate::group::persistence::GroupPersistence;
use crate::indexer::IndexerScheduler;
use crate::metrics::CollabRealtimeMetrics;
use database::collab::CollabStorage;
use indexer::scheduler::IndexerScheduler;
/// A group used to manage a single [Collab] object
pub struct CollabGroup {
@ -105,7 +104,7 @@ impl CollabGroup {
pub async fn generate_embeddings(&self) {
let result = self
.indexer_scheduler
.index_collab(
.index_collab_immediately(
&self.workspace_id,
&self.object_id,
&self.collab,

View file

@ -13,15 +13,14 @@ use app_error::AppError;
use collab_rt_entity::user::RealtimeUser;
use collab_rt_entity::CollabMessage;
use database::collab::{CollabStorage, GetCollabOrigin};
use database_entity::dto::QueryCollabParams;
use crate::client::client_msg_router::ClientMessageRouter;
use crate::error::{CreateGroupFailedReason, RealtimeError};
use crate::group::group_init::CollabGroup;
use crate::group::state::GroupManagementState;
use crate::indexer::IndexerScheduler;
use crate::metrics::CollabRealtimeMetrics;
use database::collab::{CollabStorage, GetCollabOrigin};
use database_entity::dto::QueryCollabParams;
use indexer::scheduler::IndexerScheduler;
pub struct GroupManager<S> {
state: GroupManagementState,

View file

@ -1,20 +1,20 @@
use std::sync::Arc;
use std::time::Duration;
use crate::group::group_init::EditState;
use anyhow::anyhow;
use app_error::AppError;
use collab::lock::RwLock;
use collab::preclude::Collab;
use collab_document::document::DocumentBody;
use collab_entity::{validate_data_for_folder, CollabType};
use database::collab::CollabStorage;
use database_entity::dto::CollabParams;
use indexer::scheduler::{IndexerScheduler, UnindexedCollabTask, UnindexedData};
use tokio::time::interval;
use tokio_util::sync::CancellationToken;
use tracing::{trace, warn};
use app_error::AppError;
use database::collab::CollabStorage;
use database_entity::dto::CollabParams;
use crate::group::group_init::EditState;
use crate::indexer::IndexerScheduler;
use uuid::Uuid;
pub(crate) struct GroupPersistence<S> {
workspace_id: String,
@ -124,17 +124,38 @@ where
let collab_type = self.collab_type.clone();
let cloned_collab = self.collab.clone();
let indexer_scheduler = self.indexer_scheduler.clone();
let params = tokio::task::spawn_blocking(move || {
let collab = cloned_collab.blocking_read();
let params = get_encode_collab(&workspace_id, &object_id, &collab, &collab_type)?;
match collab_type {
CollabType::Document => {
let txn = collab.transact();
let text = DocumentBody::from_collab(&collab)
.and_then(|doc| doc.to_plain_text(txn, false, true).ok());
if let Some(text) = text {
let pending = UnindexedCollabTask::new(
Uuid::parse_str(&workspace_id)?,
object_id.clone(),
collab_type,
UnindexedData::UnindexedText(text),
);
if let Err(err) = indexer_scheduler.index_pending_collab_one(pending, true) {
warn!("fail to index collab: {}:{}", object_id, err);
}
}
},
_ => {
// TODO(nathan): support other collab types
},
}
Ok::<_, AppError>(params)
})
.await??;
self
.indexer_scheduler
.index_encoded_collab_one(&self.workspace_id, &params)?;
self
.storage
.queue_insert_or_update_collab(&self.workspace_id, &self.uid, params, flush_to_disk)

View file

@ -1,10 +0,0 @@
mod document_indexer;
mod indexer_scheduler;
pub mod metrics;
mod open_ai;
mod provider;
mod vector;
pub use document_indexer::DocumentIndexer;
pub use indexer_scheduler::*;
pub use provider::*;

View file

@ -1,68 +0,0 @@
use crate::indexer::vector::rest::check_response;
use anyhow::anyhow;
use app_error::AppError;
use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse};
use serde::de::DeserializeOwned;
use std::time::Duration;
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
pub const REQUEST_PARALLELISM: usize = 40;
#[derive(Debug, Clone)]
pub struct Embedder {
bearer: String,
client: ureq::Agent,
}
impl Embedder {
pub fn new(api_key: String) -> Self {
let bearer = format!("Bearer {api_key}");
let client = ureq::AgentBuilder::new()
.max_idle_connections(REQUEST_PARALLELISM * 2)
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
.build();
Self { bearer, client }
}
pub fn embed(&self, params: EmbeddingRequest) -> Result<OpenAIEmbeddingResponse, AppError> {
for attempt in 0..3 {
let request = self
.client
.post(OPENAI_EMBEDDINGS_URL)
.set("Authorization", &self.bearer)
.set("Content-Type", "application/json");
let result = check_response(request.send_json(&params));
let retry_duration = match result {
Ok(response) => {
let data = from_response::<OpenAIEmbeddingResponse>(response)?;
return Ok(data);
},
Err(retry) => retry.into_duration(attempt),
}
.map_err(|err| AppError::Internal(err.into()))?;
let retry_duration = retry_duration.min(Duration::from_secs(10));
std::thread::sleep(retry_duration);
}
Err(AppError::Internal(anyhow!(
"Failed to generate embeddings after 3 attempts"
)))
}
}
pub fn from_response<T>(resp: ureq::Response) -> Result<T, anyhow::Error>
where
T: DeserializeOwned,
{
let status_code = resp.status();
if status_code != 200 {
let body = resp.into_string()?;
anyhow::bail!("error code: {}, {}", status_code, body)
}
let resp = resp.into_json()?;
Ok(resp)
}

View file

@ -9,7 +9,6 @@ pub mod config;
pub mod connect_state;
pub mod error;
mod group;
pub mod indexer;
pub mod metrics;
mod permission;
mod pg_listener;

View file

@ -15,8 +15,6 @@ use tracing::{error, info, trace, warn};
use yrs::updates::decoder::Decode;
use yrs::StateVector;
use database::collab::CollabStorage;
use crate::client::client_msg_router::ClientMessageRouter;
use crate::command::{spawn_collaboration_command, CLCommandReceiver};
use crate::config::get_env_var;
@ -24,8 +22,9 @@ use crate::connect_state::ConnectState;
use crate::error::{CreateGroupFailedReason, RealtimeError};
use crate::group::cmd::{GroupCommand, GroupCommandRunner, GroupCommandSender};
use crate::group::manager::GroupManager;
use crate::indexer::IndexerScheduler;
use crate::rt_server::collaboration_runtime::COLLAB_RUNTIME;
use database::collab::CollabStorage;
use indexer::scheduler::IndexerScheduler;
use crate::actix_ws::entities::{ClientGenerateEmbeddingMessage, ClientHttpUpdateMessage};
use crate::{CollabRealtimeMetrics, RealtimeClientWebsocketSink};

View file

@ -6,17 +6,16 @@ use futures_util::StreamExt;
use sqlx::PgPool;
use uuid::Uuid;
use access_control::metrics::AccessControlMetrics;
use app_error::AppError;
use database::user::{select_all_uid_uuid, select_uid_from_uuid};
use crate::collab::storage::CollabAccessControlStorage;
use crate::config::Config;
use crate::indexer::metrics::EmbeddingMetrics;
use crate::indexer::IndexerScheduler;
use crate::metrics::CollabMetrics;
use crate::pg_listener::PgListeners;
use crate::CollabRealtimeMetrics;
use access_control::metrics::AccessControlMetrics;
use app_error::AppError;
use database::user::{select_all_uid_uuid, select_uid_from_uuid};
use indexer::metrics::EmbeddingMetrics;
use indexer::scheduler::IndexerScheduler;
pub type RedisConnectionManager = redis::aio::ConnectionManager;

View file

@ -9,10 +9,18 @@ pub fn init_subscriber(app_env: &Environment) {
START.call_once(|| {
let level = std::env::var("RUST_LOG").unwrap_or("info".to_string());
let mut filters = vec![];
filters.push(format!("appflowy_collaborate={}", level));
filters.push(format!("actix_web={}", level));
filters.push(format!("collab={}", level));
filters.push(format!("collab_sync={}", level));
filters.push(format!("appflowy_cloud={}", level));
filters.push(format!("collab_plugins={}", level));
filters.push(format!("realtime={}", level));
filters.push(format!("database={}", level));
filters.push(format!("storage={}", level));
filters.push(format!("gotrue={}", level));
filters.push(format!("appflowy_collaborate={}", level));
filters.push(format!("appflowy_ai_client={}", level));
filters.push(format!("indexer={}", level));
let env_filter = EnvFilter::new(filters.join(","));
let builder = tracing_subscriber::fmt()

View file

@ -63,3 +63,9 @@ base64.workspace = true
prometheus-client = "0.22.3"
reqwest.workspace = true
zstd.workspace = true
indexer.workspace = true
appflowy-collaborate = { path = "../appflowy-collaborate" }
rayon = "1.10.0"
app-error = { workspace = true, features = [
"sqlx_error",
] }

View file

@ -15,10 +15,13 @@ use secrecy::ExposeSecret;
use crate::mailer::AFWorkerMailer;
use crate::metric::ImportMetrics;
use appflowy_worker::indexer_worker::{run_background_indexer, BackgroundIndexerConfig};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::get;
use indexer::metrics::EmbeddingMetrics;
use indexer::thread_pool::ThreadPoolNoAbortBuilder;
use infra::env_util::get_env_var;
use mailer::sender::Mailer;
use std::sync::{Arc, Once};
@ -124,6 +127,28 @@ pub async fn create_app(listener: TcpListener, config: Config) -> Result<(), Err
maximum_import_file_size,
));
let threads = Arc::new(
ThreadPoolNoAbortBuilder::new()
.num_threads(20)
.thread_name(|index| format!("background-embedding-thread-{index}"))
.build()
.unwrap(),
);
tokio::spawn(run_background_indexer(
state.pg_pool.clone(),
state.redis_client.clone(),
state.metrics.embedder_metrics.clone(),
threads.clone(),
BackgroundIndexerConfig {
enable: appflowy_collaborate::config::get_env_var("APPFLOWY_INDEXER_ENABLED", "true")
.parse::<bool>()
.unwrap_or(true),
open_api_key: appflowy_collaborate::config::get_env_var("APPFLOWY_AI_OPENAI_API_KEY", ""),
tick_interval_secs: 10,
},
));
let app = Router::new()
.route("/metrics", get(metrics_handler))
.with_state(Arc::new(state));
@ -212,15 +237,18 @@ pub struct AppMetrics {
#[allow(dead_code)]
registry: Arc<prometheus_client::registry::Registry>,
import_metrics: Arc<ImportMetrics>,
embedder_metrics: Arc<EmbeddingMetrics>,
}
impl AppMetrics {
pub fn new() -> Self {
let mut registry = prometheus_client::registry::Registry::default();
let import_metrics = Arc::new(ImportMetrics::register(&mut registry));
let embedder_metrics = Arc::new(EmbeddingMetrics::register(&mut registry));
Self {
registry: Arc::new(registry),
import_metrics,
embedder_metrics,
}
}
}

View file

@ -59,7 +59,7 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::fs;
use tokio::task::spawn_local;
use tokio::time::interval;
use tokio::time::{interval, MissedTickBehavior};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tracing::{error, info, trace, warn};
use uuid::Uuid;
@ -177,6 +177,7 @@ async fn process_upcoming_tasks(
.group(group_name, consumer_name)
.count(10);
let mut interval = interval(Duration::from_secs(interval_secs));
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
interval.tick().await;
loop {

View file

@ -0,0 +1,2 @@
mod worker;
pub use worker::*;

View file

@ -0,0 +1,247 @@
use app_error::AppError;
use collab_entity::CollabType;
use database::index::get_collabs_indexed_at;
use indexer::collab_indexer::{Indexer, IndexerProvider};
use indexer::entity::EmbeddingRecord;
use indexer::error::IndexerError;
use indexer::metrics::EmbeddingMetrics;
use indexer::queue::{
ack_task, default_indexer_group_option, ensure_indexer_consumer_group,
read_background_embed_tasks,
};
use indexer::scheduler::{spawn_pg_write_embeddings, UnindexedCollabTask, UnindexedData};
use indexer::thread_pool::ThreadPoolNoAbort;
use indexer::vector::embedder::Embedder;
use indexer::vector::open_ai;
use rayon::prelude::*;
use redis::aio::ConnectionManager;
use sqlx::PgPool;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::RwLock;
use tokio::time::{interval, MissedTickBehavior};
use tracing::{error, info, trace};
/// Configuration for the background embedding indexer.
#[derive(Debug)]
pub struct BackgroundIndexerConfig {
  // When false, `run_background_indexer` logs a message and returns immediately.
  pub enable: bool,
  // OpenAI API key used to build the embedder; the indexer stops if this is empty.
  pub open_api_key: String,
  // Seconds between polls of the Redis embed-task stream.
  pub tick_interval_secs: u64,
}
/// Entry point for the background embedding indexer.
///
/// Wires up and runs two cooperating futures until either one stops:
/// - a Postgres writer (`spawn_pg_write_embeddings`) that persists
///   `EmbeddingRecord`s received over an unbounded channel, and
/// - a polling loop (`process_upcoming_tasks`) that reads embed tasks from
///   the Redis stream and generates embeddings on the rayon thread pool.
///
/// Returns early (after logging) when indexing is disabled via config or
/// when no OpenAI API key is configured.
pub async fn run_background_indexer(
  pg_pool: PgPool,
  mut redis_client: ConnectionManager,
  embed_metrics: Arc<EmbeddingMetrics>,
  threads: Arc<ThreadPoolNoAbort>,
  config: BackgroundIndexerConfig,
) {
  if !config.enable {
    info!("Background indexer is disabled. Stop background indexer");
    return;
  }
  if config.open_api_key.is_empty() {
    error!("OpenAI API key is not set. Stop background indexer");
    return;
  }
  let indexer_provider = IndexerProvider::new();
  info!("Starting background indexer...");
  // Create the Redis consumer group up front. A failure here is logged but not
  // fatal: the read loop also recreates the group when it finds it missing.
  if let Err(err) = ensure_indexer_consumer_group(&mut redis_client).await {
    error!("Failed to ensure indexer consumer group: {:?}", err);
  }
  // Shared slot holding the most recent Postgres write error; the task loop
  // inspects it each tick to back off when writes time out.
  let latest_write_embedding_err = Arc::new(RwLock::new(None));
  // Generated embedding records flow from the worker pool to the writer task
  // over this channel.
  let (write_embedding_tx, write_embedding_rx) = unbounded_channel::<EmbeddingRecord>();
  let write_embedding_task_fut = spawn_pg_write_embeddings(
    write_embedding_rx,
    pg_pool.clone(),
    embed_metrics.clone(),
    latest_write_embedding_err.clone(),
  );
  let process_tasks_task_fut = process_upcoming_tasks(
    pg_pool,
    &mut redis_client,
    embed_metrics,
    indexer_provider,
    threads,
    config,
    write_embedding_tx,
    latest_write_embedding_err,
  );
  // Both futures are expected to run forever; log which side stopped first.
  tokio::select! {
    _ = write_embedding_task_fut => {
      error!("[Background Embedding] Write embedding task stopped");
    },
    _ = process_tasks_task_fut => {
      error!("[Background Embedding] Process tasks task stopped");
    },
  }
}
/// Polls the Redis stream for pending embed tasks and processes them.
///
/// Each tick this loop:
/// 1. Backs off for 15s if the previous Postgres write failed with a timeout.
/// 2. Reads a batch of tasks from the stream (batch size bounded by the
///    thread-pool width via `default_indexer_group_option`).
/// 3. Drops tasks whose payload is empty or whose `created_at` predates the
///    collab's stored `indexed_at` (already indexed).
/// 4. Generates embeddings in parallel on the rayon pool and forwards each
///    resulting record to the Postgres writer via `sender`.
/// 5. Acks (deletes) every stream entry that was read, regardless of the
///    per-task outcome, so the stream does not grow unbounded.
///
/// Runs forever; errors are logged and the loop keeps polling.
#[allow(clippy::too_many_arguments)]
async fn process_upcoming_tasks(
  pg_pool: PgPool,
  redis_client: &mut ConnectionManager,
  metrics: Arc<EmbeddingMetrics>,
  indexer_provider: Arc<IndexerProvider>,
  threads: Arc<ThreadPoolNoAbort>,
  config: BackgroundIndexerConfig,
  sender: UnboundedSender<EmbeddingRecord>,
  latest_write_embedding_err: Arc<RwLock<Option<AppError>>>,
) {
  let options = default_indexer_group_option(threads.current_num_threads());
  let mut interval = interval(Duration::from_secs(config.tick_interval_secs));
  interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
  interval.tick().await;
  loop {
    interval.tick().await;
    // If the last embedding write timed out, pause before pulling more work so
    // the writer task can catch up.
    let latest_error = latest_write_embedding_err.write().await.take();
    if let Some(err) = latest_error {
      if matches!(err, AppError::ActionTimeout(_)) {
        // NOTE: message previously claimed 30s while the code sleeps 15s;
        // the message now matches the actual back-off duration.
        info!(
          "[Background Embedding] last write embedding task failed with timeout, waiting for 15s before retrying..."
        );
        tokio::time::sleep(Duration::from_secs(15)).await;
      }
    }
    match read_background_embed_tasks(redis_client, &options).await {
      Ok(replay) => {
        // Remember every stream id that was read so the whole batch can be
        // acked after processing.
        let all_keys: Vec<String> = replay
          .keys
          .iter()
          .flat_map(|key| key.ids.iter().map(|stream_id| stream_id.id.clone()))
          .collect();
        for key in replay.keys {
          info!(
            "[Background Embedding] processing {} embedding tasks",
            key.ids.len()
          );
          let mut tasks: Vec<UnindexedCollabTask> = key
            .ids
            .into_iter()
            .filter_map(|stream_id| UnindexedCollabTask::try_from(&stream_id).ok())
            .collect();
          // Tasks with no content cannot produce embeddings; drop them early.
          tasks.retain(|task| !task.data.is_empty());
          let collab_ids: Vec<(String, CollabType)> = tasks
            .iter()
            .map(|task| (task.object_id.clone(), task.collab_type.clone()))
            .collect();
          let indexed_collabs = get_collabs_indexed_at(&pg_pool, collab_ids)
            .await
            .unwrap_or_default();
          let all_tasks_len = tasks.len();
          if !indexed_collabs.is_empty() {
            // Filter out tasks where `created_at` is less than `indexed_at`
            tasks.retain(|task| {
              indexed_collabs
                .get(&task.object_id)
                .map_or(true, |indexed_at| task.created_at > indexed_at.timestamp())
            });
          }
          if all_tasks_len != tasks.len() {
            info!("[Background Embedding] filter out {} tasks where `created_at` is less than `indexed_at`", all_tasks_len - tasks.len());
          }
          let start = Instant::now();
          let num_tasks = tasks.len();
          // Fan tasks out over the rayon pool; each worker builds its own
          // embedder and streams the finished record to the writer task.
          tasks.into_par_iter().for_each(|task| {
            let result = threads.install(|| {
              if let Some(indexer) = indexer_provider.indexer_for(&task.collab_type) {
                let embedder = create_embedder(&config);
                let result = handle_task(embedder, indexer, task);
                match result {
                  None => metrics.record_failed_embed_count(1),
                  Some(record) => {
                    metrics.record_embed_count(1);
                    trace!(
                      "[Background Embedding] send {} embedding record to write task",
                      record.object_id
                    );
                    if let Err(err) = sender.send(record) {
                      trace!(
                        "[Background Embedding] failed to send embedding record to write task: {:?}",
                        err
                      );
                    }
                  },
                }
              }
            });
            // `install` itself can fail (e.g. the pool was poisoned); the task
            // is lost but the loop keeps running.
            if let Err(err) = result {
              error!(
                "[Background Embedding] Failed to process embedder task: {:?}",
                err
              );
            }
          });
          let cost = start.elapsed().as_millis();
          metrics.record_gen_embedding_time(num_tasks as u32, cost);
        }
        // Ack (and delete) everything that was read — including tasks that
        // failed — so failed entries are not redelivered forever.
        if !all_keys.is_empty() {
          match ack_task(redis_client, all_keys, true).await {
            Ok(_) => trace!("[Background embedding]: delete tasks from stream"),
            Err(err) => {
              error!("[Background Embedding] Failed to ack tasks: {:?}", err);
            },
          }
        }
      },
      Err(err) => {
        error!("[Background Embedding] Failed to read tasks: {:?}", err);
        // The consumer group may have disappeared (e.g. Redis was flushed);
        // recreate it so the next poll can succeed.
        if matches!(err, IndexerError::StreamGroupNotExist(_)) {
          if let Err(err) = ensure_indexer_consumer_group(redis_client).await {
            error!(
              "[Background Embedding] Failed to ensure indexer consumer group: {:?}",
              err
            );
          }
        }
      },
    }
  }
}
/// Generates embeddings for a single unindexed-collab task.
///
/// Splits the task's text into chunks via the indexer, embeds the chunks, and
/// packages the result as an `EmbeddingRecord` for the Postgres writer.
/// Returns `None` when chunking or embedding fails — the caller counts that as
/// a failed embed — or when `embed` yields no embeddings.
fn handle_task(
  embedder: Embedder,
  indexer: Arc<dyn Indexer>,
  task: UnindexedCollabTask,
) -> Option<EmbeddingRecord> {
  trace!(
    "[Background Embedding] processing task: {}, content:{:?}, collab_type: {}",
    task.object_id,
    task.data,
    task.collab_type
  );
  let chunks = match task.data {
    UnindexedData::UnindexedText(text) => indexer
      .create_embedded_chunks_from_text(task.object_id.clone(), text, embedder.model())
      .ok()?,
  };
  // `embed` returns Result<Option<_>, _> here: errors are swallowed via
  // `.ok()?`, and `None` presumably means the content produced no embeddings
  // (e.g. nothing worth indexing) — TODO confirm against the Indexer trait.
  let embeddings = indexer.embed(&embedder, chunks).ok()?;
  embeddings.map(|embeddings| EmbeddingRecord {
    workspace_id: task.workspace_id,
    object_id: task.object_id,
    collab_type: task.collab_type,
    tokens_used: embeddings.tokens_consumed,
    contents: embeddings.params,
  })
}
/// Builds an OpenAI-backed embedder from the configured API key.
///
/// Called once per task inside the rayon workers, so construction is assumed
/// to be cheap (no network I/O) — NOTE(review): confirm `open_ai::Embedder::new`
/// only stores the key.
fn create_embedder(config: &BackgroundIndexerConfig) -> Embedder {
  Embedder::OpenAI(open_ai::Embedder::new(config.open_api_key.clone()))
}

View file

@ -1,5 +1,6 @@
pub mod error;
pub mod import_worker;
pub mod indexer_worker;
mod mailer;
pub mod metric;
pub mod s3_client;

View file

@ -9,7 +9,7 @@ use async_trait::async_trait;
use byteorder::{ByteOrder, LittleEndian};
use chrono::Utc;
use collab_rt_entity::user::RealtimeUser;
use collab_rt_protocol::spawn_blocking_validate_encode_collab;
use collab_rt_protocol::validate_encode_collab;
use database_entity::dto::CollabParams;
use std::str::FromStr;
use tokio_stream::StreamExt;
@ -119,13 +119,9 @@ pub trait CollabValidator {
#[async_trait]
impl CollabValidator for CollabParams {
async fn check_encode_collab(&self) -> Result<(), AppError> {
spawn_blocking_validate_encode_collab(
&self.object_id,
&self.encoded_collab_v1,
&self.collab_type,
)
.await
.map_err(|err| AppError::NoRequiredData(err.to_string()))
validate_encode_collab(&self.object_id, &self.encoded_collab_v1, &self.collab_type)
.await
.map_err(|err| AppError::NoRequiredData(err.to_string()))
}
}

View file

@ -1,5 +1,5 @@
use crate::api::util::{client_version_from_headers, realtime_user_for_web_request, PayloadReader};
use crate::api::util::{compress_type_from_header_value, device_id_from_headers, CollabValidator};
use crate::api::util::{compress_type_from_header_value, device_id_from_headers};
use crate::api::ws::RealtimeServerAddr;
use crate::biz;
use crate::biz::collab::ops::{
@ -32,23 +32,28 @@ use actix_web::{HttpRequest, Result};
use anyhow::{anyhow, Context};
use app_error::AppError;
use appflowy_collaborate::actix_ws::entities::{ClientHttpStreamMessage, ClientHttpUpdateMessage};
use appflowy_collaborate::indexer::IndexedCollab;
use authentication::jwt::{Authorization, OptionalUserUuid, UserUuid};
use bytes::BytesMut;
use chrono::{DateTime, Duration, Utc};
use collab::core::collab::DataSource;
use collab::core::origin::CollabOrigin;
use collab::entity::EncodedCollab;
use collab::preclude::Collab;
use collab_database::entity::FieldType;
use collab_document::document::Document;
use collab_entity::CollabType;
use collab_folder::timestamp;
use collab_rt_entity::collab_proto::{CollabDocStateParams, PayloadCompressionType};
use collab_rt_entity::realtime_proto::HttpRealtimeMessage;
use collab_rt_entity::user::RealtimeUser;
use collab_rt_entity::RealtimeMessage;
use collab_rt_protocol::validate_encode_collab;
use collab_rt_protocol::collab_from_encode_collab;
use database::collab::{CollabStorage, GetCollabOrigin};
use database::user::select_uid_from_email;
use database_entity::dto::PublishCollabItem;
use database_entity::dto::PublishInfo;
use database_entity::dto::*;
use indexer::scheduler::{UnindexedCollabTask, UnindexedData};
use prost::Message as ProstMessage;
use rayon::prelude::*;
use sha2::{Digest, Sha256};
@ -63,6 +68,7 @@ use tokio_tungstenite::tungstenite::Message;
use tracing::{error, event, instrument, trace};
use uuid::Uuid;
use validator::Validate;
pub const WORKSPACE_ID_PATH: &str = "workspace_id";
pub const COLLAB_OBJECT_ID_PATH: &str = "object_id";
@ -706,7 +712,16 @@ async fn create_collab_handler(
);
}
if let Err(err) = params.check_encode_collab().await {
let collab = collab_from_encode_collab(&params.object_id, &params.encoded_collab_v1)
.await
.map_err(|err| {
AppError::NoRequiredData(format!(
"Failed to create collab from encoded collab: {}",
err
))
})?;
if let Err(err) = params.collab_type.validate_require_data(&collab) {
return Err(
AppError::NoRequiredData(format!(
"collab doc state is not correct:{},{}",
@ -721,9 +736,19 @@ async fn create_collab_handler(
.can_index_workspace(&workspace_id)
.await?
{
state
.indexer_scheduler
.index_encoded_collab_one(&workspace_id, IndexedCollab::from(&params))?;
if let Ok(text) = Document::open(collab).and_then(|doc| doc.to_plain_text(false, true)) {
let workspace_id_uuid =
Uuid::parse_str(&workspace_id).map_err(|err| AppError::Internal(err.into()))?;
let pending = UnindexedCollabTask::new(
workspace_id_uuid,
params.object_id.clone(),
params.collab_type.clone(),
UnindexedData::UnindexedText(text),
);
state
.indexer_scheduler
.index_pending_collab_one(pending, false)?;
}
}
let mut transaction = state
@ -759,7 +784,8 @@ async fn batch_create_collab_handler(
req: HttpRequest,
) -> Result<Json<AppResponse<()>>> {
let uid = state.user_cache.get_user_uid(&user_uuid).await?;
let workspace_id = workspace_id.into_inner().to_string();
let workspace_id_uuid = workspace_id.into_inner();
let workspace_id = workspace_id_uuid.to_string();
let compress_type = compress_type_from_header_value(req.headers())?;
event!(tracing::Level::DEBUG, "start decompressing collab list");
@ -791,7 +817,7 @@ async fn batch_create_collab_handler(
}
}
// Perform decompression and processing in a Rayon thread pool
let collab_params_list = tokio::task::spawn_blocking(move || match compress_type {
let mut collab_params_list = tokio::task::spawn_blocking(move || match compress_type {
CompressionType::Brotli { buffer_size } => offset_len_list
.into_par_iter()
.filter_map(|(offset, len)| {
@ -800,12 +826,31 @@ async fn batch_create_collab_handler(
Ok(decompressed_data) => {
let params = CollabParams::from_bytes(&decompressed_data).ok()?;
if params.validate().is_ok() {
match validate_encode_collab(
let encoded_collab =
EncodedCollab::decode_from_bytes(&params.encoded_collab_v1).ok()?;
let collab = Collab::new_with_source(
CollabOrigin::Empty,
&params.object_id,
&params.encoded_collab_v1,
&params.collab_type,
) {
Ok(_) => Some(params),
DataSource::DocStateV1(encoded_collab.doc_state.to_vec()),
vec![],
false,
)
.ok()?;
match params.collab_type.validate_require_data(&collab) {
Ok(_) => {
match params.collab_type {
CollabType::Document => {
let index_text =
Document::open(collab).and_then(|doc| doc.to_plain_text(false, true));
Some((Some(index_text), params))
},
_ => {
// TODO(nathan): support other types
Some((None, params))
},
}
},
Err(_) => None,
}
} else {
@ -829,7 +874,7 @@ async fn batch_create_collab_handler(
let total_size = collab_params_list
.iter()
.fold(0, |acc, x| acc + x.encoded_collab_v1.len());
.fold(0, |acc, x| acc + x.1.encoded_collab_v1.len());
event!(
tracing::Level::INFO,
"decompressed {} collab objects in {:?}",
@ -837,23 +882,39 @@ async fn batch_create_collab_handler(
start.elapsed()
);
// if state
// .indexer_scheduler
// .can_index_workspace(&workspace_id)
// .await?
// {
// let indexed_collabs: Vec<_> = collab_params_list
// .iter()
// .filter(|p| state.indexer_scheduler.is_indexing_enabled(&p.collab_type))
// .map(IndexedCollab::from)
// .collect();
//
// if !indexed_collabs.is_empty() {
// state
// .indexer_scheduler
// .index_encoded_collabs(&workspace_id, indexed_collabs)?;
// }
// }
let mut pending_undexed_collabs = vec![];
if state
.indexer_scheduler
.can_index_workspace(&workspace_id)
.await?
{
pending_undexed_collabs = collab_params_list
.iter_mut()
.filter(|p| {
state
.indexer_scheduler
.is_indexing_enabled(&p.1.collab_type)
})
.flat_map(|value| match std::mem::take(&mut value.0) {
None => None,
Some(text) => text
.map(|text| {
UnindexedCollabTask::new(
workspace_id_uuid,
value.1.object_id.clone(),
value.1.collab_type.clone(),
UnindexedData::UnindexedText(text),
)
})
.ok(),
})
.collect::<Vec<_>>();
}
let collab_params_list = collab_params_list
.into_iter()
.map(|(_, params)| params)
.collect::<Vec<_>>();
let start = Instant::now();
state
@ -868,6 +929,13 @@ async fn batch_create_collab_handler(
total_size
);
// Must after batch_insert_new_collab
if !pending_undexed_collabs.is_empty() {
state
.indexer_scheduler
.index_pending_collabs(pending_undexed_collabs)?;
}
Ok(Json(AppResponse::Ok()))
}
@ -1366,9 +1434,45 @@ async fn update_collab_handler(
.can_index_workspace(&workspace_id)
.await?
{
state
.indexer_scheduler
.index_encoded_collab_one(&workspace_id, IndexedCollab::from(&params))?;
let workspace_id_uuid =
Uuid::parse_str(&workspace_id).map_err(|err| AppError::Internal(err.into()))?;
match params.collab_type {
CollabType::Document => {
let collab = collab_from_encode_collab(&params.object_id, &params.encoded_collab_v1)
.await
.map_err(|err| {
AppError::InvalidRequest(format!(
"Failed to create collab from encoded collab: {}",
err
))
})?;
params
.collab_type
.validate_require_data(&collab)
.map_err(|err| {
AppError::NoRequiredData(format!(
"collab doc state is not correct:{},{}",
params.object_id, err
))
})?;
if let Ok(text) = Document::open(collab).and_then(|doc| doc.to_plain_text(false, true)) {
let pending = UnindexedCollabTask::new(
workspace_id_uuid,
params.object_id.clone(),
params.collab_type.clone(),
UnindexedData::UnindexedText(text),
);
state
.indexer_scheduler
.index_pending_collab_one(pending, false)?;
}
},
_ => {
// TODO(nathan): support other collab type
},
}
}
state

View file

@ -40,10 +40,11 @@ use appflowy_collaborate::actix_ws::server::RealtimeServerActor;
use appflowy_collaborate::collab::cache::CollabCache;
use appflowy_collaborate::collab::storage::CollabStorageImpl;
use appflowy_collaborate::command::{CLCommandReceiver, CLCommandSender};
use appflowy_collaborate::indexer::{IndexerConfiguration, IndexerProvider, IndexerScheduler};
use appflowy_collaborate::snapshot::SnapshotControl;
use appflowy_collaborate::CollaborationServer;
use database::file::s3_client_impl::{AwsS3BucketClientImpl, S3BucketStorage};
use indexer::collab_indexer::IndexerProvider;
use indexer::scheduler::{IndexerConfiguration, IndexerScheduler};
use infra::env_util::get_env_var;
use mailer::sender::Mailer;
use snowflake::Snowflake;
@ -320,10 +321,16 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result<A
info!("Setting up Indexer scheduler...");
let embedder_config = IndexerConfiguration {
enable: appflowy_collaborate::config::get_env_var("APPFLOWY_INDEXER_ENABLED", "true")
enable: get_env_var("APPFLOWY_INDEXER_ENABLED", "true")
.parse::<bool>()
.unwrap_or(true),
openai_api_key: get_env_var("APPFLOWY_AI_OPENAI_API_KEY", ""),
embedding_buffer_size: appflowy_collaborate::config::get_env_var(
"APPFLOWY_INDEXER_EMBEDDING_BUFFER_SIZE",
"5000",
)
.parse::<usize>()
.unwrap_or(5000),
};
let indexer_scheduler = IndexerScheduler::new(
IndexerProvider::new(),
@ -331,6 +338,7 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result<A
collab_access_control_storage.clone(),
metrics.embedding_metrics.clone(),
embedder_config,
redis_conn_manager.clone(),
);
info!("Application state initialized");

View file

@ -12,7 +12,7 @@ use shared_entity::dto::search_dto::{
use shared_entity::response::AppResponseError;
use sqlx::PgPool;
use appflowy_collaborate::indexer::IndexerScheduler;
use indexer::scheduler::IndexerScheduler;
use uuid::Uuid;
pub async fn search_document(
@ -23,7 +23,7 @@ pub async fn search_document(
request: SearchDocumentRequest,
metrics: &RequestMetrics,
) -> Result<Vec<SearchDocumentResponseItem>, AppResponseError> {
let embeddings = indexer_scheduler.embeddings(EmbeddingRequest {
let embeddings = indexer_scheduler.create_search_embeddings(EmbeddingRequest {
input: EmbeddingInput::String(request.query.clone()),
model: EmbeddingModel::TextEmbedding3Small.to_string(),
encoding_format: EmbeddingEncodingFormat::Float,

View file

@ -18,6 +18,7 @@ async fn main() -> anyhow::Result<()> {
filters.push(format!("gotrue={}", level));
filters.push(format!("appflowy_collaborate={}", level));
filters.push(format!("appflowy_ai_client={}", level));
filters.push(format!("indexer={}", level));
// Load environment variables from .env file
dotenvy::dotenv().ok();

View file

@ -15,14 +15,13 @@ use app_error::AppError;
use appflowy_ai_client::client::AppFlowyAIClient;
use appflowy_collaborate::collab::cache::CollabCache;
use appflowy_collaborate::collab::storage::CollabAccessControlStorage;
use appflowy_collaborate::indexer::metrics::EmbeddingMetrics;
use appflowy_collaborate::indexer::IndexerScheduler;
use appflowy_collaborate::metrics::CollabMetrics;
use appflowy_collaborate::CollabRealtimeMetrics;
use database::file::s3_client_impl::{AwsS3BucketClientImpl, S3BucketStorage};
use database::user::{select_all_uid_uuid, select_uid_from_uuid};
use gotrue::grant::{Grant, PasswordGrant};
use indexer::metrics::EmbeddingMetrics;
use indexer::scheduler::IndexerScheduler;
use snowflake::Snowflake;
use tonic_proto::history::history_client::HistoryClient;

View file

@ -11,7 +11,7 @@ use reqwest::Method;
use serde::Serialize;
use serde_json::json;
use crate::collab::util::{generate_random_string, test_encode_collab_v1};
use crate::collab::util::{empty_document_editor, generate_random_string, test_encode_collab_v1};
use client_api_test::TestClient;
use shared_entity::response::AppResponse;
use uuid::Uuid;
@ -50,77 +50,6 @@ async fn batch_insert_collab_with_empty_payload_test() {
assert_eq!(error.code, ErrorCode::InvalidRequest);
}
/// Inserts 240 collabs of mixed payload sizes (2 KB / 10 KB / 800 KB) in one
/// batch, then batch-fetches them and verifies each returned doc state matches
/// what was uploaded.
#[tokio::test]
async fn batch_insert_collab_success_test() {
  let mut test_client = TestClient::new_user().await;
  let workspace_id = test_client.workspace_id().await;
  let mut mock_encoded_collab = vec![];
  // 200 small (~2 KB) payloads.
  for _ in 0..200 {
    let object_id = Uuid::new_v4().to_string();
    let encoded_collab_v1 =
      test_encode_collab_v1(&object_id, "title", &generate_random_string(2 * 1024));
    mock_encoded_collab.push(encoded_collab_v1);
  }
  // 30 medium (~10 KB) payloads.
  for _ in 0..30 {
    let object_id = Uuid::new_v4().to_string();
    let encoded_collab_v1 =
      test_encode_collab_v1(&object_id, "title", &generate_random_string(10 * 1024));
    mock_encoded_collab.push(encoded_collab_v1);
  }
  // 10 large (~800 KB) payloads.
  for _ in 0..10 {
    let object_id = Uuid::new_v4().to_string();
    let encoded_collab_v1 =
      test_encode_collab_v1(&object_id, "title", &generate_random_string(800 * 1024));
    mock_encoded_collab.push(encoded_collab_v1);
  }
  // Note: fresh object ids are generated here, so the ids embedded in the
  // encoded payloads above are not the ones used for insertion.
  let params_list = mock_encoded_collab
    .iter()
    .map(|encoded_collab_v1| CollabParams {
      object_id: Uuid::new_v4().to_string(),
      encoded_collab_v1: encoded_collab_v1.encode_to_bytes().unwrap().into(),
      collab_type: CollabType::Unknown,
    })
    .collect::<Vec<_>>();
  test_client
    .create_collab_list(&workspace_id, params_list.clone())
    .await
    .unwrap();
  let params = params_list
    .iter()
    .map(|params| QueryCollab {
      object_id: params.object_id.clone(),
      collab_type: params.collab_type.clone(),
    })
    .collect::<Vec<_>>();
  let result = test_client
    .batch_get_collab(&workspace_id, params)
    .await
    .unwrap();
  // Every inserted collab must come back with an identical doc state.
  for params in params_list {
    let encoded_collab = result.0.get(&params.object_id).unwrap();
    match encoded_collab {
      QueryCollabResult::Success { encode_collab_v1 } => {
        let actual = EncodedCollab::decode_from_bytes(encode_collab_v1.as_ref()).unwrap();
        let expected = EncodedCollab::decode_from_bytes(params.encoded_collab_v1.as_ref()).unwrap();
        assert_eq!(actual.doc_state, expected.doc_state);
      },
      QueryCollabResult::Failed { error } => {
        panic!("Failed to get collab: {:?}", error);
      },
    }
  }
  // 200 + 30 + 10 inserted collabs.
  assert_eq!(result.0.values().len(), 240);
}
#[tokio::test]
async fn create_collab_params_compatibility_serde_test() {
// This test is to make sure that the CreateCollabParams is compatible with the old InsertCollabParams
@ -218,6 +147,69 @@ async fn create_collab_compatibility_with_json_params_test() {
assert_eq!(encoded_collab, encoded_collab_from_server);
}
/// Batch-inserts 100 document collabs (each with three short paragraphs),
/// batch-fetches them back, and verifies each returned doc state matches the
/// uploaded one.
#[tokio::test]
async fn batch_insert_document_collab_test() {
  let mut test_client = TestClient::new_user().await;
  let workspace_id = test_client.workspace_id().await;
  let num_collabs = 100;
  let mut list = vec![];
  for _ in 0..num_collabs {
    let object_id = Uuid::new_v4().to_string();
    let mut editor = empty_document_editor(&object_id);
    // Three tiny paragraphs give each document some indexable text content.
    let paragraphs = vec![
      generate_random_string(1),
      generate_random_string(2),
      generate_random_string(5),
    ];
    editor.insert_paragraphs(paragraphs);
    list.push((object_id, editor.encode_collab()));
  }
  let params_list = list
    .iter()
    .map(|(object_id, encoded_collab_v1)| CollabParams {
      object_id: object_id.clone(),
      encoded_collab_v1: encoded_collab_v1.encode_to_bytes().unwrap().into(),
      collab_type: CollabType::Document,
    })
    .collect::<Vec<_>>();
  test_client
    .create_collab_list(&workspace_id, params_list.clone())
    .await
    .unwrap();
  let params = params_list
    .iter()
    .map(|params| QueryCollab {
      object_id: params.object_id.clone(),
      collab_type: params.collab_type.clone(),
    })
    .collect::<Vec<_>>();
  let result = test_client
    .batch_get_collab(&workspace_id, params)
    .await
    .unwrap();
  // Every inserted document must come back with an identical doc state.
  for params in params_list {
    let encoded_collab = result.0.get(&params.object_id).unwrap();
    match encoded_collab {
      QueryCollabResult::Success { encode_collab_v1 } => {
        let actual = EncodedCollab::decode_from_bytes(encode_collab_v1.as_ref()).unwrap();
        let expected = EncodedCollab::decode_from_bytes(params.encoded_collab_v1.as_ref()).unwrap();
        assert_eq!(actual.doc_state, expected.doc_state);
      },
      QueryCollabResult::Failed { error } => {
        panic!("Failed to get collab: {:?}", error);
      },
    }
  }
  assert_eq!(result.0.values().len(), num_collabs);
}
#[derive(Debug, Clone, Serialize)]
pub struct OldCreateCollabParams {
#[serde(flatten)]

View file

@ -199,7 +199,7 @@ async fn fail_insert_collab_with_empty_payload_test() {
.create_collab(CreateCollabParams {
object_id: Uuid::new_v4().to_string(),
encoded_collab_v1: vec![],
collab_type: CollabType::Unknown,
collab_type: CollabType::Document,
workspace_id: workspace_id.clone(),
})
.await