From 1fd900d9941ea1fedb1a723bf0907863454ff465 Mon Sep 17 00:00:00 2001 From: Bartosz Sypytkowski Date: Sun, 6 Apr 2025 11:47:02 +0200 Subject: [PATCH] Document paragraphs (#1119) * chore: create embeddings by paragraphs * chore: use document paragraphs method * chore: document indexing by paragraphs with consistent hash * chore: compare produced embeddings against existing ones * chore: make pg stored proc compare between input and existing embedded fragments * chore: missing sqlx generation * fix: appflowy worker * chore: make sure that embeddings are only changed when content had changed * chore: remove partition key and recreate af_collab_embeddings_upsert migration * chore: use pg15 on CI and update af_collab_embeddings table primary key * chore: fix test --------- Co-authored-by: Nathan --- ...66411b5034639df91b739f5cbe2af0ffb6811.json | 28 +++ Cargo.lock | 27 ++- Cargo.toml | 14 +- docker-compose-ci.yml | 2 +- docker-compose-dev.yml | 2 +- libs/appflowy-ai-client/src/dto.rs | 2 +- libs/database-entity/src/dto.rs | 2 +- .../src/index/collab_embeddings_ops.rs | 39 +++- libs/indexer/Cargo.toml | 4 +- .../src/collab_indexer/document_indexer.rs | 113 ++++----- libs/indexer/src/collab_indexer/provider.rs | 6 +- libs/indexer/src/scheduler.rs | 216 +++++++++--------- libs/indexer/src/unindexed_workspace.rs | 180 +++++++++------ libs/indexer/src/vector/open_ai.rs | 73 +++--- ...0405092732_af_collab_embeddings_upsert.sql | 91 ++++++++ .../src/group/group_init.rs | 10 +- .../tests/indexer_test.rs | 4 +- .../src/import_worker/worker.rs | 4 +- .../src/indexer_worker/worker.rs | 132 ++++++----- src/api/workspace.rs | 33 +-- src/biz/collab/ops.rs | 7 +- src/biz/collab/utils.rs | 4 +- tests/ai_test/chat_with_selected_doc_test.rs | 2 +- tests/collab/collab_embedding_test.rs | 6 +- tests/collab/database_crud.rs | 8 +- 25 files changed, 601 insertions(+), 408 deletions(-) create mode 100644 .sqlx/query-90afca9cc8b6d4ca31e8ddf1ce466411b5034639df91b739f5cbe2af0ffb6811.json create mode 100644 migrations/20250405092732_af_collab_embeddings_upsert.sql diff --git a/.sqlx/query-90afca9cc8b6d4ca31e8ddf1ce466411b5034639df91b739f5cbe2af0ffb6811.json b/.sqlx/query-90afca9cc8b6d4ca31e8ddf1ce466411b5034639df91b739f5cbe2af0ffb6811.json new file mode 100644 index 00000000..473c8beb --- /dev/null +++ b/.sqlx/query-90afca9cc8b6d4ca31e8ddf1ce466411b5034639df91b739f5cbe2af0ffb6811.json @@ -0,0 +1,28 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT oid, fragment_id\n FROM af_collab_embeddings\n WHERE oid = ANY($1::uuid[])\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "oid", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "fragment_id", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "UuidArray" + ] + }, + "nullable": [ + false, + false + ] + }, + "hash": "90afca9cc8b6d4ca31e8ddf1ce466411b5034639df91b739f5cbe2af0ffb6811" +} diff --git a/Cargo.lock b/Cargo.lock index 6c7ddb67..59b96ffc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1888,7 +1888,7 @@ dependencies = [ [[package]] name = "collab" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=80d1c6147d1139289c2eaadab40557cc86c0f4b6#80d1c6147d1139289c2eaadab40557cc86c0f4b6" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" dependencies = [ "anyhow", "arc-swap", @@ -1913,7 +1913,7 @@ dependencies = [ [[package]] name = "collab-database" version = "0.2.0" -source = 
"git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=80d1c6147d1139289c2eaadab40557cc86c0f4b6#80d1c6147d1139289c2eaadab40557cc86c0f4b6" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" dependencies = [ "anyhow", "async-trait", @@ -1953,7 +1953,7 @@ dependencies = [ [[package]] name = "collab-document" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=80d1c6147d1139289c2eaadab40557cc86c0f4b6#80d1c6147d1139289c2eaadab40557cc86c0f4b6" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" dependencies = [ "anyhow", "arc-swap", @@ -1974,7 +1974,7 @@ dependencies = [ [[package]] name = "collab-entity" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=80d1c6147d1139289c2eaadab40557cc86c0f4b6#80d1c6147d1139289c2eaadab40557cc86c0f4b6" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" dependencies = [ "anyhow", "bytes", @@ -1994,7 +1994,7 @@ dependencies = [ [[package]] name = "collab-folder" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=80d1c6147d1139289c2eaadab40557cc86c0f4b6#80d1c6147d1139289c2eaadab40557cc86c0f4b6" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" dependencies = [ "anyhow", "arc-swap", @@ -2016,7 +2016,7 @@ dependencies = [ [[package]] name = "collab-importer" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=80d1c6147d1139289c2eaadab40557cc86c0f4b6#80d1c6147d1139289c2eaadab40557cc86c0f4b6" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" dependencies = [ "anyhow", "async-recursion", @@ -2124,7 +2124,7 @@ dependencies = [ [[package]] name = "collab-user" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=80d1c6147d1139289c2eaadab40557cc86c0f4b6#80d1c6147d1139289c2eaadab40557cc86c0f4b6" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b#3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" dependencies = [ "anyhow", "collab", @@ -3828,8 +3828,6 @@ dependencies = [ "collab", "collab-document", "collab-entity", - "collab-folder", - "collab-stream", "database", "database-entity", "futures-util", @@ -3846,7 +3844,7 @@ dependencies = [ "tiktoken-rs", "tokio", "tracing", - "unicode-segmentation", + "twox-hash", "ureq", "uuid", ] @@ -7412,6 +7410,15 @@ dependencies = [ "utf-8", ] +[[package]] +name = "twox-hash" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7b17f197b3050ba473acf9181f7b1d3b66d1cf7356c6cc57886662276e65908" +dependencies = [ + "rand 0.8.5", +] + [[package]] name = "typenum" version = "1.17.0" diff --git a/Cargo.toml b/Cargo.toml index 1fe6366a..543465ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -303,13 +303,13 @@ lto = false [patch.crates-io] # It's diffcult to resovle different version with the same crate used in AppFlowy Frontend and the Client-API crate. # So using patch to workaround this issue. 
-collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "80d1c6147d1139289c2eaadab40557cc86c0f4b6" } -collab-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "80d1c6147d1139289c2eaadab40557cc86c0f4b6" } -collab-folder = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "80d1c6147d1139289c2eaadab40557cc86c0f4b6" } -collab-document = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "80d1c6147d1139289c2eaadab40557cc86c0f4b6" } -collab-user = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "80d1c6147d1139289c2eaadab40557cc86c0f4b6" } -collab-database = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "80d1c6147d1139289c2eaadab40557cc86c0f4b6" } -collab-importer = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "80d1c6147d1139289c2eaadab40557cc86c0f4b6" } +collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } +collab-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } +collab-folder = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } +collab-document = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } +collab-user = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } +collab-database = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } +collab-importer = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3b1deca704cc1d8ae4fdc9cb053d7da824d0b85b" } [features] history = [] diff --git a/docker-compose-ci.yml b/docker-compose-ci.yml index 108e41c1..30a93d52 100644 --- a/docker-compose-ci.yml +++ b/docker-compose-ci.yml @@ -35,7 +35,7 @@ services: postgres: restart: on-failure - image: pgvector/pgvector:pg16 + image: pgvector/pgvector:pg15 ports: - "5432:5432" healthcheck: diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 2c4803ee..29f69f78 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -21,7 +21,7 @@ services: postgres: restart: on-failure - image: pgvector/pgvector:pg16 + image: pgvector/pgvector:pg15 environment: - POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_DB=${POSTGRES_DB:-postgres} diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index bd756dc2..87061abd 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -283,7 +283,7 @@ pub struct EmbeddingRequest { pub dimensions: i32, } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum EmbeddingModel { #[serde(rename = "text-embedding-3-small")] TextEmbedding3Small, diff --git a/libs/database-entity/src/dto.rs b/libs/database-entity/src/dto.rs index e30e39ed..994613e8 100644 --- a/libs/database-entity/src/dto.rs +++ b/libs/database-entity/src/dto.rs @@ -762,7 +762,7 @@ pub struct AFCollabEmbeddedChunk { #[serde(with = "uuid_str")] pub object_id: Uuid, pub content_type: EmbeddingContentType, - pub content: String, + pub content: Option, pub embedding: Option>, pub metadata: serde_json::Value, pub fragment_index: i32, diff --git a/libs/database/src/index/collab_embeddings_ops.rs b/libs/database/src/index/collab_embeddings_ops.rs index 
5b215701..0d70f709 100644 --- a/libs/database/src/index/collab_embeddings_ops.rs +++ b/libs/database/src/index/collab_embeddings_ops.rs @@ -64,7 +64,7 @@ WHERE w.workspace_id = $1"#, struct Fragment { fragment_id: String, content_type: i32, - contents: String, + contents: Option, embedding: Option, metadata: serde_json::Value, fragment_index: i32, @@ -100,9 +100,13 @@ pub async fn upsert_collab_embeddings( ) -> Result<(), sqlx::Error> { let fragments = records.into_iter().map(Fragment::from).collect::>(); tracing::trace!( - "[Embedding] upsert {} {} fragments", + "[Embedding] upsert {} {} fragments, fragment ids: {:?}", object_id, - fragments.len() + fragments.len(), + fragments + .iter() + .map(|v| v.fragment_id.clone()) + .collect::>() ); sqlx::query(r#"CALL af_collab_embeddings_upsert($1, $2, $3, $4::af_fragment_v3[])"#) .bind(*workspace_id) @@ -114,6 +118,35 @@ pub async fn upsert_collab_embeddings( Ok(()) } +pub async fn get_collab_embedding_fragment_ids<'a, E>( + tx: E, + collab_ids: Vec, +) -> Result>, sqlx::Error> +where + E: Executor<'a, Database = Postgres>, +{ + let records = sqlx::query!( + r#" + SELECT oid, fragment_id + FROM af_collab_embeddings + WHERE oid = ANY($1::uuid[]) + "#, + &collab_ids, + ) + .fetch_all(tx) + .await?; + + let mut fragment_ids_by_oid = HashMap::new(); + for record in records { + // If your record.oid is not a String, convert it as needed. + fragment_ids_by_oid + .entry(record.oid) + .or_insert_with(Vec::new) + .push(record.fragment_id); + } + Ok(fragment_ids_by_oid) +} + pub async fn stream_collabs_without_embeddings( conn: &mut PoolConnection, workspace_id: Uuid, diff --git a/libs/indexer/Cargo.toml b/libs/indexer/Cargo.toml index 1e226fc3..a492a249 100644 --- a/libs/indexer/Cargo.toml +++ b/libs/indexer/Cargo.toml @@ -8,12 +8,9 @@ rayon.workspace = true tiktoken-rs = "0.6.0" app-error = { workspace = true } appflowy-ai-client = { workspace = true, features = ["client-api"] } -unicode-segmentation = "1.12.0" collab = { workspace = true } collab-entity = { workspace = true } -collab-folder = { workspace = true } collab-document = { workspace = true } -collab-stream = { workspace = true } database-entity.workspace = true database.workspace = true futures-util.workspace = true @@ -37,3 +34,4 @@ redis = { workspace = true, features = [ ] } secrecy = { workspace = true, features = ["serde"] } reqwest.workspace = true +twox-hash = { version = "2.1.0", features = ["xxhash64"] } diff --git a/libs/indexer/src/collab_indexer/document_indexer.rs b/libs/indexer/src/collab_indexer/document_indexer.rs index 879fc18c..60a994e9 100644 --- a/libs/indexer/src/collab_indexer/document_indexer.rs +++ b/libs/indexer/src/collab_indexer/document_indexer.rs @@ -1,6 +1,6 @@ use crate::collab_indexer::Indexer; use crate::vector::embedder::Embedder; -use crate::vector::open_ai::split_text_by_max_content_len; +use crate::vector::open_ai::group_paragraphs_by_max_content_len; use anyhow::anyhow; use app_error::AppError; use appflowy_ai_client::dto::{ @@ -9,11 +9,11 @@ use appflowy_ai_client::dto::{ use async_trait::async_trait; use collab::preclude::Collab; use collab_document::document::DocumentBody; -use collab_document::error::DocumentError; use collab_entity::CollabType; use database_entity::dto::{AFCollabEmbeddedChunk, AFCollabEmbeddings, EmbeddingContentType}; use serde_json::json; -use tracing::trace; +use tracing::{debug, trace}; +use twox_hash::xxhash64::Hasher; use uuid::Uuid; pub struct DocumentIndexer; @@ -23,7 +23,7 @@ impl Indexer for DocumentIndexer { fn 
create_embedded_chunks_from_collab( &self, collab: &Collab, - embedding_model: EmbeddingModel, + model: EmbeddingModel, ) -> Result, AppError> { let object_id = collab.object_id().parse()?; let document = DocumentBody::from_collab(collab).ok_or_else(|| { @@ -33,29 +33,20 @@ impl Indexer for DocumentIndexer { ) })?; - let result = document.to_plain_text(collab.transact(), false, true); - match result { - Ok(content) => self.create_embedded_chunks_from_text(object_id, content, embedding_model), - Err(err) => { - if matches!(err, DocumentError::NoRequiredData) { - Ok(vec![]) - } else { - Err(AppError::Internal(err.into())) - } - }, - } + let paragraphs = document.paragraphs(collab.transact()); + self.create_embedded_chunks_from_text(object_id, paragraphs, model) } fn create_embedded_chunks_from_text( &self, object_id: Uuid, - text: String, + paragraphs: Vec, model: EmbeddingModel, ) -> Result, AppError> { - split_text_into_chunks(object_id, text, CollabType::Document, &model) + split_text_into_chunks(object_id, paragraphs, CollabType::Document, model) } - fn embed( + async fn embed( &self, embedder: &Embedder, mut content: Vec, @@ -66,14 +57,16 @@ impl Indexer for DocumentIndexer { let contents: Vec<_> = content .iter() - .map(|fragment| fragment.content.clone()) + .map(|fragment| fragment.content.clone().unwrap_or_default()) .collect(); - let resp = embedder.embed(EmbeddingRequest { - input: EmbeddingInput::StringArray(contents), - model: embedder.model().name().to_string(), - encoding_format: EmbeddingEncodingFormat::Float, - dimensions: EmbeddingModel::TextEmbedding3Small.default_dimensions(), - })?; + let resp = embedder + .async_embed(EmbeddingRequest { + input: EmbeddingInput::StringArray(contents), + model: embedder.model().name().to_string(), + encoding_format: EmbeddingEncodingFormat::Float, + dimensions: EmbeddingModel::TextEmbedding3Small.default_dimensions(), + }) + .await?; trace!( "[Embedding] request {} embeddings, received {} embeddings", @@ -83,15 +76,18 @@ impl Indexer for DocumentIndexer { for embedding in resp.data { let param = &mut content[embedding.index as usize]; - let embedding: Vec = match embedding.embedding { - EmbeddingOutput::Float(embedding) => embedding.into_iter().map(|f| f as f32).collect(), - EmbeddingOutput::Base64(_) => { - return Err(AppError::OpenError( - "Unexpected base64 encoding".to_string(), - )) - }, - }; - param.embedding = Some(embedding); + if param.content.is_some() { + // we only set the embedding if the content was not marked as unchanged + let embedding: Vec = match embedding.embedding { + EmbeddingOutput::Float(embedding) => embedding.into_iter().map(|f| f as f32).collect(), + EmbeddingOutput::Base64(_) => { + return Err(AppError::OpenError( + "Unexpected base64 encoding".to_string(), + )) + }, + }; + param.embedding = Some(embedding); + } } Ok(Some(AFCollabEmbeddings { @@ -100,39 +96,52 @@ impl Indexer for DocumentIndexer { })) } } - fn split_text_into_chunks( object_id: Uuid, - content: String, + paragraphs: Vec, collab_type: CollabType, - embedding_model: &EmbeddingModel, + embedding_model: EmbeddingModel, ) -> Result, AppError> { debug_assert!(matches!( embedding_model, EmbeddingModel::TextEmbedding3Small )); - if content.is_empty() { + if paragraphs.is_empty() { return Ok(vec![]); } - // We assume that every token is ~4 bytes. We're going to split document content into fragments - // of ~2000 tokens each. 
- let split_contents = split_text_by_max_content_len(content, 8000)?; - let metadata = json!({"id": object_id.to_string(), "source": "appflowy", "name": "document", "collab_type": collab_type }); - Ok( - split_contents - .into_iter() - .enumerate() - .map(|(index, content)| AFCollabEmbeddedChunk { - fragment_id: Uuid::new_v4().to_string(), + // Group paragraphs into chunks of roughly 8000 characters. + let split_contents = group_paragraphs_by_max_content_len(paragraphs, 8000); + let metadata = json!({ + "id": object_id, + "source": "appflowy", + "name": "document", + "collab_type": collab_type + }); + + let mut seen = std::collections::HashSet::new(); + let mut chunks = Vec::new(); + + for (index, content) in split_contents.into_iter().enumerate() { + let consistent_hash = Hasher::oneshot(0, content.as_bytes()); + let fragment_id = format!("{:x}", consistent_hash); + if seen.insert(fragment_id.clone()) { + chunks.push(AFCollabEmbeddedChunk { + fragment_id, object_id, content_type: EmbeddingContentType::PlainText, - content, + content: Some(content), embedding: None, metadata: metadata.clone(), fragment_index: index as i32, embedded_type: 0, - }) - .collect(), - ) + }); + } else { + debug!( + "[Embedding] Duplicate fragment_id detected: {}. This fragment will not be added.", + fragment_id + ); + } + } + Ok(chunks) } diff --git a/libs/indexer/src/collab_indexer/provider.rs b/libs/indexer/src/collab_indexer/provider.rs index 22bac99e..3004bec5 100644 --- a/libs/indexer/src/collab_indexer/provider.rs +++ b/libs/indexer/src/collab_indexer/provider.rs @@ -2,6 +2,7 @@ use crate::collab_indexer::DocumentIndexer; use crate::vector::embedder::Embedder; use app_error::AppError; use appflowy_ai_client::dto::EmbeddingModel; +use async_trait::async_trait; use collab::preclude::Collab; use collab_entity::CollabType; use database_entity::dto::{AFCollabEmbeddedChunk, AFCollabEmbeddings}; @@ -11,6 +12,7 @@ use std::sync::Arc; use tracing::info; use uuid::Uuid; +#[async_trait] pub trait Indexer: Send + Sync { fn create_embedded_chunks_from_collab( &self, @@ -21,11 +23,11 @@ pub trait Indexer: Send + Sync { fn create_embedded_chunks_from_text( &self, object_id: Uuid, - text: String, + paragraphs: Vec, model: EmbeddingModel, ) -> Result, AppError>; - fn embed( + async fn embed( &self, embedder: &Embedder, content: Vec, diff --git a/libs/indexer/src/scheduler.rs b/libs/indexer/src/scheduler.rs index a3a4d5c4..5eae189c 100644 --- a/libs/indexer/src/scheduler.rs +++ b/libs/indexer/src/scheduler.rs @@ -1,9 +1,7 @@ -use crate::collab_indexer::{Indexer, IndexerProvider}; +use crate::collab_indexer::IndexerProvider; use crate::entity::EmbeddingRecord; -use crate::error::IndexerError; use crate::metrics::EmbeddingMetrics; use crate::queue::add_background_embed_task; -use crate::thread_pool::{ThreadPoolNoAbort, ThreadPoolNoAbortBuilder}; use crate::vector::embedder::Embedder; use crate::vector::open_ai; use app_error::AppError; @@ -12,11 +10,11 @@ use collab::preclude::Collab; use collab_document::document::DocumentBody; use collab_entity::CollabType; use database::collab::CollabStorage; -use database::index::{update_collab_indexed_at, upsert_collab_embeddings}; +use database::index::{ + get_collab_embedding_fragment_ids, update_collab_indexed_at, upsert_collab_embeddings, +}; use database::workspace::select_workspace_settings; -use database_entity::dto::AFCollabEmbeddedChunk; use infra::env_util::get_env_var; -use rayon::prelude::*; use redis::aio::ConnectionManager; use secrecy::{ExposeSecret, Secret}; use 
serde::{Deserialize, Serialize}; @@ -30,6 +28,7 @@ use tokio::sync::mpsc; use tokio::sync::mpsc::error::TrySendError; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::RwLock as TokioRwLock; +use tokio::task::JoinSet; use tokio::time::timeout; use tracing::{debug, error, info, instrument, trace, warn}; use uuid::Uuid; @@ -38,7 +37,6 @@ pub struct IndexerScheduler { pub(crate) indexer_provider: Arc, pub(crate) pg_pool: PgPool, pub(crate) storage: Arc, - pub(crate) threads: Arc, #[allow(dead_code)] pub(crate) metrics: Arc, write_embedding_tx: UnboundedSender, @@ -77,19 +75,11 @@ impl IndexerScheduler { let (write_embedding_tx, write_embedding_rx) = unbounded_channel::(); let (gen_embedding_tx, gen_embedding_rx) = mpsc::channel::(config.embedding_buffer_size); - let threads = Arc::new( - ThreadPoolNoAbortBuilder::new() - .num_threads(num_thread) - .thread_name(|index| format!("create-embedding-thread-{index}")) - .build() - .unwrap(), - ); let this = Arc::new(Self { indexer_provider, pg_pool, storage, - threads, metrics, write_embedding_tx, gen_embedding_tx, @@ -105,7 +95,7 @@ impl IndexerScheduler { let latest_write_embedding_err = Arc::new(TokioRwLock::new(None)); if this.index_enabled() { - tokio::spawn(spawn_rayon_generate_embeddings( + tokio::spawn(generate_embeddings_loop( gen_embedding_rx, Arc::downgrade(&this), num_thread, @@ -258,18 +248,17 @@ impl IndexerScheduler { CollabType::Document => { let txn = collab.transact(); let text = DocumentBody::from_collab(collab) - .and_then(|body| body.to_plain_text(txn, false, true).ok()); + .map(|body| body.paragraphs(txn)) + .unwrap_or_default(); - if let Some(text) = text { - if !text.is_empty() { - let pending = UnindexedCollabTask::new( - workspace_id, - object_id, - collab_type, - UnindexedData::Text(text), - ); - self.embed_immediately(pending)?; - } + if !text.is_empty() { + let pending = UnindexedCollabTask::new( + workspace_id, + object_id, + collab_type, + UnindexedData::Paragraphs(text), + ); + self.embed_immediately(pending)?; } }, _ => { @@ -293,7 +282,7 @@ impl IndexerScheduler { } } -async fn spawn_rayon_generate_embeddings( +async fn generate_embeddings_loop( mut rx: mpsc::Receiver, scheduler: Weak, buffer_size: usize, @@ -332,60 +321,99 @@ async fn spawn_rayon_generate_embeddings( records.len() ); let metrics = scheduler.metrics.clone(); - let threads = scheduler.threads.clone(); let indexer_provider = scheduler.indexer_provider.clone(); let write_embedding_tx = scheduler.write_embedding_tx.clone(); let embedder = scheduler.create_embedder(); - let result = tokio::task::spawn_blocking(move || { - match embedder { - Ok(embedder) => { - records.into_par_iter().for_each(|record| { - let result = threads.install(|| { - let indexer = indexer_provider.indexer_for(record.collab_type); - match process_collab(&embedder, indexer, record.object_id, record.data, &metrics) { - Ok(Some((tokens_used, contents))) => { - if let Err(err) = write_embedding_tx.send(EmbeddingRecord { - workspace_id: record.workspace_id, - object_id: record.object_id, - collab_type: record.collab_type, - tokens_used, - contents, - }) { - error!("Failed to send embedding record: {}", err); + match embedder { + Ok(embedder) => { + let params: Vec<_> = records.iter().map(|r| r.object_id).collect(); + let existing_embeddings = + match get_collab_embedding_fragment_ids(&scheduler.pg_pool, params).await { + Ok(existing_embeddings) => existing_embeddings, + Err(err) => { + error!("[Embedding] failed to get existing embeddings: 
{}", err); + Default::default() + }, + }; + let mut join_set = JoinSet::new(); + for record in records { + if let Some(indexer) = indexer_provider.indexer_for(record.collab_type) { + metrics.record_embed_count(1); + let paragraphs = match record.data { + UnindexedData::Paragraphs(paragraphs) => paragraphs, + UnindexedData::Text(text) => text.split('\n').map(|s| s.to_string()).collect(), + }; + let embedder = embedder.clone(); + match indexer.create_embedded_chunks_from_text( + record.object_id, + paragraphs, + embedder.model(), + ) { + Ok(mut chunks) => { + if let Some(fragment_ids) = existing_embeddings.get(&record.object_id) { + for chunk in chunks.iter_mut() { + if fragment_ids.contains(&chunk.fragment_id) { + // we already had an embedding for this chunk + chunk.content = None; + chunk.embedding = None; + } + } + } + join_set.spawn(async move { + if chunks.is_empty() { + return Ok(None); } - }, - Ok(None) => { - debug!("No embedding for collab:{}", record.object_id); - }, - Err(err) => { - warn!( - "Failed to create embeddings content for collab:{}, error:{}", - record.object_id, err - ); - }, - } - }); - if let Err(err) = result { - error!("Failed to install a task to rayon thread pool: {}", err); + let result = indexer.embed(&embedder, chunks).await; + match result { + Ok(Some(embeddings)) => { + let record = EmbeddingRecord { + workspace_id: record.workspace_id, + object_id: record.object_id, + collab_type: record.collab_type, + tokens_used: embeddings.tokens_consumed, + contents: embeddings.params, + }; + Ok(Some(record)) + }, + Ok(None) => Ok(None), + Err(err) => Err(err), + } + }); + }, + Err(err) => { + metrics.record_failed_embed_count(1); + warn!( + "Failed to create embedded chunks for collab: {}, error:{}", + record.object_id, err + ); + continue; + }, } - }); - }, - Err(err) => error!("[Embedding] Failed to create embedder: {}", err), - } - Ok::<_, IndexerError>(()) - }) - .await; - - match result { - Ok(Ok(_)) => { - scheduler - .metrics - .record_gen_embedding_time(n as u32, start.elapsed().as_millis()); - trace!("Successfully generated embeddings"); + } + } + while let Some(Ok(res)) = join_set.join_next().await { + scheduler + .metrics + .record_gen_embedding_time(n as u32, start.elapsed().as_millis()); + match res { + Ok(Some(record)) => { + if let Err(err) = write_embedding_tx.send(record) { + error!("Failed to send embedding record: {}", err); + } + }, + Ok(None) => debug!("No embedding for collab"), + Err(err) => { + metrics.record_failed_embed_count(1); + warn!( + "Failed to create embeddings content for collab, error:{}", + err + ); + }, + } + } }, - Ok(Err(err)) => error!("Failed to generate embeddings: {}", err), - Err(err) => error!("Failed to spawn a task to generate embeddings: {}", err), + Err(err) => error!("[Embedding] Failed to create embedder: {}", err), } } } @@ -409,7 +437,7 @@ pub async fn spawn_pg_write_embeddings( let start = Instant::now(); let records = buf.drain(..n).collect::>(); for record in records.iter() { - info!( + debug!( "[Embedding] generate collab:{} embeddings, tokens used: {}", record.object_id, record.tokens_used ); @@ -477,40 +505,6 @@ pub(crate) async fn batch_insert_records( Ok(()) } -/// This function must be called within the rayon thread pool. 
-fn process_collab( - embedder: &Embedder, - indexer: Option>, - object_id: Uuid, - data: UnindexedData, - metrics: &EmbeddingMetrics, -) -> Result)>, AppError> { - if let Some(indexer) = indexer { - let chunks = match data { - UnindexedData::Text(text) => { - indexer.create_embedded_chunks_from_text(object_id, text, embedder.model())? - }, - }; - - if chunks.is_empty() { - return Ok(None); - } - - metrics.record_embed_count(1); - let result = indexer.embed(embedder, chunks); - match result { - Ok(Some(embeddings)) => Ok(Some((embeddings.tokens_consumed, embeddings.params))), - Ok(None) => Ok(None), - Err(err) => { - metrics.record_failed_embed_count(1); - Err(err) - }, - } - } else { - Ok(None) - } -} - #[derive(Debug, Serialize, Deserialize)] pub struct UnindexedCollabTask { pub workspace_id: Uuid, @@ -540,12 +534,14 @@ impl UnindexedCollabTask { #[derive(Debug, Serialize, Deserialize)] pub enum UnindexedData { Text(String), + Paragraphs(Vec), } impl UnindexedData { pub fn is_empty(&self) -> bool { match self { UnindexedData::Text(text) => text.is_empty(), + UnindexedData::Paragraphs(text) => text.is_empty(), } } } diff --git a/libs/indexer/src/unindexed_workspace.rs b/libs/indexer/src/unindexed_workspace.rs index d7e441ff..8692544b 100644 --- a/libs/indexer/src/unindexed_workspace.rs +++ b/libs/indexer/src/unindexed_workspace.rs @@ -1,38 +1,31 @@ use crate::collab_indexer::IndexerProvider; use crate::entity::{EmbeddingRecord, UnindexedCollab}; use crate::scheduler::{batch_insert_records, IndexerScheduler}; -use crate::thread_pool::ThreadPoolNoAbort; use crate::vector::embedder::Embedder; +use appflowy_ai_client::dto::EmbeddingModel; use collab::core::collab::DataSource; use collab::core::origin::CollabOrigin; use collab::preclude::Collab; use collab_entity::CollabType; use database::collab::{CollabStorage, GetCollabOrigin}; -use database::index::stream_collabs_without_embeddings; +use database::index::{get_collab_embedding_fragment_ids, stream_collabs_without_embeddings}; use futures_util::stream::BoxStream; use futures_util::StreamExt; use rayon::iter::ParallelIterator; use rayon::prelude::IntoParallelIterator; use sqlx::pool::PoolConnection; use sqlx::Postgres; +use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; +use tokio::task::JoinSet; use tracing::{error, info, trace}; use uuid::Uuid; #[allow(dead_code)] pub(crate) async fn index_workspace(scheduler: Arc, workspace_id: Uuid) { - let weak_threads = Arc::downgrade(&scheduler.threads); let mut retry_delay = Duration::from_secs(2); loop { - let threads = match weak_threads.upgrade() { - Some(threads) => threads, - None => { - info!("[Embedding] thread pool is dropped, stop indexing"); - break; - }, - }; - let conn = scheduler.pg_pool.try_acquire(); if conn.is_none() { tokio::time::sleep(retry_delay).await; @@ -58,23 +51,17 @@ pub(crate) async fn index_workspace(scheduler: Arc, workspace_ continue; } - index_then_write_embedding_to_disk( - &scheduler, - threads.clone(), - std::mem::take(&mut unindexed_collabs), - ) - .await; + index_then_write_embedding_to_disk(&scheduler, std::mem::take(&mut unindexed_collabs)).await; } if !unindexed_collabs.is_empty() { - index_then_write_embedding_to_disk(&scheduler, threads.clone(), unindexed_collabs).await; + index_then_write_embedding_to_disk(&scheduler, unindexed_collabs).await; } } } async fn index_then_write_embedding_to_disk( scheduler: &Arc, - threads: Arc, unindexed_collabs: Vec, ) { info!( @@ -87,32 +74,41 @@ async fn index_then_write_embedding_to_disk( 
if let Ok(embedder) = scheduler.create_embedder() { let start = Instant::now(); - let embeddings = create_embeddings( - embedder, - &scheduler.indexer_provider, - threads.clone(), - unindexed_collabs, - ) - .await; - scheduler - .metrics - .record_gen_embedding_time(embeddings.len() as u32, start.elapsed().as_millis()); + let object_ids = unindexed_collabs + .iter() + .map(|v| v.object_id) + .collect::>(); + match get_collab_embedding_fragment_ids(&scheduler.pg_pool, object_ids).await { + Ok(existing_embeddings) => { + let embeddings = create_embeddings( + embedder, + &scheduler.indexer_provider, + unindexed_collabs, + existing_embeddings, + ) + .await; + scheduler + .metrics + .record_gen_embedding_time(embeddings.len() as u32, start.elapsed().as_millis()); - let write_start = Instant::now(); - let n = embeddings.len(); - match batch_insert_records(&scheduler.pg_pool, embeddings).await { - Ok(_) => trace!( - "[Embedding] upsert {} embeddings success, cost:{}ms", - n, - write_start.elapsed().as_millis() - ), - Err(err) => error!("{}", err), + let write_start = Instant::now(); + let n = embeddings.len(); + match batch_insert_records(&scheduler.pg_pool, embeddings).await { + Ok(_) => trace!( + "[Embedding] upsert {} embeddings success, cost:{}ms", + n, + write_start.elapsed().as_millis() + ), + Err(err) => error!("{}", err), + } + + scheduler + .metrics + .record_write_embedding_time(write_start.elapsed().as_millis()); + tokio::time::sleep(Duration::from_secs(5)).await; + }, + Err(err) => error!("[Embedding] failed to get fragment ids: {}", err), } - - scheduler - .metrics - .record_write_embedding_time(write_start.elapsed().as_millis()); - tokio::time::sleep(Duration::from_secs(5)).await; } else { trace!("[Embedding] no embeddings to process in this batch"); } @@ -160,12 +156,61 @@ async fn stream_unindexed_collabs( }) .boxed() } - async fn create_embeddings( embedder: Embedder, indexer_provider: &Arc, - threads: Arc, unindexed_records: Vec, + existing_embeddings: HashMap>, +) -> Vec { + // 1. use parallel iteration since computing text chunks is CPU-intensive task + let records = compute_embedding_records( + indexer_provider, + embedder.model(), + unindexed_records, + existing_embeddings, + ); + + // 2. 
use tokio JoinSet to parallelize OpenAI calls (IO-bound) + let mut join_set = JoinSet::new(); + for record in records { + let indexer_provider = indexer_provider.clone(); + let embedder = embedder.clone(); + if let Some(indexer) = indexer_provider.indexer_for(record.collab_type) { + join_set.spawn(async move { + match indexer.embed(&embedder, record.contents).await { + Ok(embeddings) => embeddings.map(|embeddings| EmbeddingRecord { + workspace_id: record.workspace_id, + object_id: record.object_id, + collab_type: record.collab_type, + tokens_used: embeddings.tokens_consumed, + contents: embeddings.params, + }), + Err(err) => { + error!("Failed to embed collab: {}", err); + None + }, + } + }); + } + } + + let mut results = Vec::with_capacity(join_set.len()); + while let Some(Ok(Some(record))) = join_set.join_next().await { + trace!( + "[Embedding] generate collab:{} embeddings, tokens used: {}", + record.object_id, + record.tokens_used + ); + results.push(record); + } + results +} + +fn compute_embedding_records( + indexer_provider: &IndexerProvider, + model: EmbeddingModel, + unindexed_records: Vec, + existing_embeddings: HashMap>, ) -> Vec { unindexed_records .into_par_iter() @@ -180,8 +225,8 @@ async fn create_embeddings( ) .ok()?; - let chunks = indexer - .create_embedded_chunks_from_collab(&collab, embedder.model()) + let mut chunks = indexer + .create_embedded_chunks_from_collab(&collab, model) .ok()?; if chunks.is_empty() { trace!("[Embedding] {} has no embeddings", unindexed.object_id,); @@ -192,32 +237,23 @@ async fn create_embeddings( )); } - let result = threads.install(|| match indexer.embed(&embedder, chunks) { - Ok(embeddings) => embeddings.map(|embeddings| EmbeddingRecord { - workspace_id: unindexed.workspace_id, - object_id: unindexed.object_id, - collab_type: unindexed.collab_type, - tokens_used: embeddings.tokens_consumed, - contents: embeddings.params, - }), - Err(err) => { - error!("Failed to embed collab: {}", err); - None - }, - }); - - if let Ok(Some(record)) = &result { - trace!( - "[Embedding] generate collab:{} embeddings, tokens used: {}", - record.object_id, - record.tokens_used - ); + // compare chunks against existing fragment ids (which are content addressed) and mark these + // which haven't changed as already embedded + if let Some(existing_embeddings) = existing_embeddings.get(&unindexed.object_id) { + for chunk in chunks.iter_mut() { + if existing_embeddings.contains(&chunk.fragment_id) { + chunk.content = None; // mark as already embedded + chunk.embedding = None; + } + } } - - result.unwrap_or_else(|err| { - error!("Failed to spawn a task to index collab: {}", err); - None + Some(EmbeddingRecord { + workspace_id: unindexed.workspace_id, + object_id: unindexed.object_id, + collab_type: unindexed.collab_type, + tokens_used: 0, + contents: chunks, }) }) - .collect::>() + .collect() } diff --git a/libs/indexer/src/vector/open_ai.rs b/libs/indexer/src/vector/open_ai.rs index 9835aa77..a49e7d72 100644 --- a/libs/indexer/src/vector/open_ai.rs +++ b/libs/indexer/src/vector/open_ai.rs @@ -5,7 +5,6 @@ use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse}; use serde::de::DeserializeOwned; use std::time::Duration; use tiktoken_rs::CoreBPE; -use unicode_segmentation::UnicodeSegmentation; pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; @@ -184,53 +183,41 @@ pub fn split_text_by_max_tokens( Ok(chunks) } -#[inline] -pub fn split_text_by_max_content_len( - content: String, +pub fn group_paragraphs_by_max_content_len( + 
paragraphs: Vec, max_content_len: usize, -) -> Result, AppError> { - if content.is_empty() { - return Ok(vec![]); +) -> Vec { + if paragraphs.is_empty() { + return vec![]; } - if content.len() <= max_content_len { - return Ok(vec![content]); - } - - // Content is longer than max_content_len; need to split - let mut result = Vec::with_capacity(1 + content.len() / max_content_len); - let mut fragment = String::with_capacity(max_content_len); - let mut current_len = 0; - - for grapheme in content.graphemes(true) { - let grapheme_len = grapheme.len(); - if current_len + grapheme_len > max_content_len { - if !fragment.is_empty() { - result.push(std::mem::take(&mut fragment)); - } - current_len = 0; - - if grapheme_len > max_content_len { - // Push the grapheme as a fragment on its own - result.push(grapheme.to_string()); - continue; + let mut result = Vec::new(); + let mut current = String::new(); + for paragraph in paragraphs { + if paragraph.len() + current.len() > max_content_len { + // if we add the paragraph to the current content, it will exceed the limit + // so we push the current content to the result set and start a new chunk + let accumulated = std::mem::replace(&mut current, paragraph); + if !accumulated.is_empty() { + result.push(accumulated); } + } else { + // add the paragraph to the current chunk + current.push_str(¶graph); } - fragment.push_str(grapheme); - current_len += grapheme_len; } - // Add the last fragment if it's not empty - if !fragment.is_empty() { - result.push(fragment); + if !current.is_empty() { + result.push(current); } - Ok(result) + + result } #[cfg(test)] mod tests { - use crate::vector::open_ai::{split_text_by_max_content_len, split_text_by_max_tokens}; + use crate::vector::open_ai::{group_paragraphs_by_max_content_len, split_text_by_max_tokens}; use tiktoken_rs::cl100k_base; #[test] @@ -246,7 +233,7 @@ mod tests { assert!(content.is_char_boundary(content.len())); } - let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let params = group_paragraphs_by_max_content_len(vec![content], max_tokens); for content in params { assert!(content.is_char_boundary(0)); assert!(content.is_char_boundary(content.len())); @@ -283,7 +270,7 @@ mod tests { let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); assert_eq!(params.len(), 0); - let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let params = group_paragraphs_by_max_content_len(params, max_tokens); assert_eq!(params.len(), 0); } @@ -299,7 +286,7 @@ mod tests { assert_eq!(param, emoji); } - let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let params = group_paragraphs_by_max_content_len(params, max_tokens); for (param, emoji) in params.iter().zip(emojis.iter()) { assert_eq!(param, emoji); } @@ -317,7 +304,7 @@ mod tests { let reconstructed_content = params.join(""); assert_eq!(reconstructed_content, content); - let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let params = group_paragraphs_by_max_content_len(params, max_tokens); let reconstructed_content: String = params.concat(); assert_eq!(reconstructed_content, content); } @@ -347,7 +334,7 @@ mod tests { let reconstructed_content: String = params.concat(); assert_eq!(reconstructed_content, content); - let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let params = group_paragraphs_by_max_content_len(params, max_tokens); let reconstructed_content: String = 
params.concat(); assert_eq!(reconstructed_content, content); } @@ -365,7 +352,7 @@ mod tests { let reconstructed_content: String = params.concat(); assert_eq!(reconstructed_content, content); - let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let params = group_paragraphs_by_max_content_len(params, max_tokens); let reconstructed_content: String = params.concat(); assert_eq!(reconstructed_content, content); } @@ -379,7 +366,7 @@ mod tests { let reconstructed_content: String = params.concat(); assert_eq!(reconstructed_content, content); - let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let params = group_paragraphs_by_max_content_len(params, max_tokens); let reconstructed_content: String = params.concat(); assert_eq!(reconstructed_content, content); } @@ -393,7 +380,7 @@ mod tests { let reconstructed_content: String = params.concat(); assert_eq!(reconstructed_content, content); - let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let params = group_paragraphs_by_max_content_len(params, max_tokens); let reconstructed_content: String = params.concat(); assert_eq!(reconstructed_content, content); } diff --git a/migrations/20250405092732_af_collab_embeddings_upsert.sql b/migrations/20250405092732_af_collab_embeddings_upsert.sql new file mode 100644 index 00000000..40a0e975 --- /dev/null +++ b/migrations/20250405092732_af_collab_embeddings_upsert.sql @@ -0,0 +1,91 @@ +-- Drop existing primary key if it exists: +ALTER TABLE af_collab_embeddings +DROP CONSTRAINT IF EXISTS af_collab_embeddings_pkey; + +-- Add a new composite primary key on (fragment_id, oid): +-- Currently the fragment_id is generated by hash fragment content, so fragment_id might be +-- conflicting with other fragments, but they are not in the same document. +ALTER TABLE af_collab_embeddings + ADD CONSTRAINT af_collab_embeddings_pkey + PRIMARY KEY (fragment_id, oid); + +CREATE OR REPLACE PROCEDURE af_collab_embeddings_upsert( + IN p_workspace_id UUID, + IN p_oid TEXT, + IN p_tokens_used INT, + IN p_fragments af_fragment_v3[] +) +LANGUAGE plpgsql +AS +$$ +BEGIN +-- Delete all fragments for p_oid that are not present in the new fragment list. +DELETE +FROM af_collab_embeddings +WHERE oid = p_oid + AND fragment_id NOT IN ( + SELECT fragment_id FROM UNNEST(p_fragments) AS f +); + +-- Use MERGE to update existing rows or insert new ones without causing duplicate key errors. +MERGE INTO af_collab_embeddings AS t + USING ( + SELECT + f.fragment_id, + p_oid AS oid, + f.content_type, + f.contents, + f.embedding, + NOW() AS indexed_at, + f.metadata, + f.fragment_index, + f.embedder_type + FROM UNNEST(p_fragments) AS f + ) AS s + ON t.oid = s.oid AND t.fragment_id = s.fragment_id + WHEN MATCHED THEN -- this fragment has not changed + UPDATE SET indexed_at = NOW() + WHEN NOT MATCHED THEN -- this fragment is new + INSERT ( + fragment_id, + oid, + content_type, + content, + embedding, + indexed_at, + metadata, + fragment_index, + embedder_type + ) + VALUES ( + s.fragment_id, + s.oid, + s.content_type, + s.contents, + s.embedding, + NOW(), + s.metadata, + s.fragment_index, + s.embedder_type + ); + +-- Update the usage tracking table with an upsert. 
+INSERT INTO af_workspace_ai_usage( + created_at, + workspace_id, + search_requests, + search_tokens_consumed, + index_tokens_consumed +) +VALUES ( + NOW()::date, + p_workspace_id, + 0, + 0, + p_tokens_used + ) + ON CONFLICT (created_at, workspace_id) + DO UPDATE SET index_tokens_consumed = af_workspace_ai_usage.index_tokens_consumed + p_tokens_used; + +END +$$; \ No newline at end of file diff --git a/services/appflowy-collaborate/src/group/group_init.rs b/services/appflowy-collaborate/src/group/group_init.rs index 476bc452..aa59d219 100644 --- a/services/appflowy-collaborate/src/group/group_init.rs +++ b/services/appflowy-collaborate/src/group/group_init.rs @@ -1081,8 +1081,6 @@ impl CollabPersister { // persisted one in the database self.save_attempt(&mut snapshot.collab, message_id).await?; } - } else { - tracing::trace!("collab {} state has not changed", self.object_id); } Ok(()) } @@ -1112,9 +1110,7 @@ impl CollabPersister { match self.collab_type { CollabType::Document => { let txn = collab.transact(); - if let Some(text) = DocumentBody::from_collab(collab) - .and_then(|body| body.to_plain_text(txn, false, true).ok()) - { + if let Some(text) = DocumentBody::from_collab(collab).map(|body| body.paragraphs(txn)) { self.index_collab_content(text); } }, @@ -1166,12 +1162,12 @@ impl CollabPersister { Ok(()) } - fn index_collab_content(&self, text: String) { + fn index_collab_content(&self, paragraphs: Vec) { let indexed_collab = UnindexedCollabTask::new( self.workspace_id, self.object_id, self.collab_type, - UnindexedData::Text(text), + UnindexedData::Paragraphs(paragraphs), ); if let Err(err) = self .indexer_scheduler diff --git a/services/appflowy-collaborate/tests/indexer_test.rs b/services/appflowy-collaborate/tests/indexer_test.rs index 367084e0..195bed0b 100644 --- a/services/appflowy-collaborate/tests/indexer_test.rs +++ b/services/appflowy-collaborate/tests/indexer_test.rs @@ -10,7 +10,7 @@ fn document_plain_text() { let doc = getting_started_document_data().unwrap(); let collab = Collab::new_with_origin(CollabOrigin::Server, "1", vec![], false); let document = Document::create_with_data(collab, doc).unwrap(); - let text = document.to_plain_text(false, true).unwrap(); + let text = document.paragraphs().join(""); let expected = "Welcome to AppFlowy $ Download for macOS, Windows, and Linux link $ $ $ quick start Ask AI powered by advanced AI models: chat, search, write, and much more ✨ ❤\u{fe0f}Love AppFlowy and open source? Follow our latest product updates: Twitter : @appflowy Reddit : r/appflowy Github "; assert_eq!(&text, expected); } @@ -20,7 +20,7 @@ fn document_plain_text_with_nested_blocks() { let doc = get_initial_document_data().unwrap(); let collab = Collab::new_with_origin(CollabOrigin::Server, "1", vec![], false); let document = Document::create_with_data(collab, doc).unwrap(); - let text = document.to_plain_text(false, true).unwrap(); + let text = document.paragraphs().join(""); let expected = "Welcome to AppFlowy! Here are the basics Here is H3 Click anywhere and just start typing. Click Enter to create a new line. Highlight any text, and use the editing menu to style your writing however you like. As soon as you type / a menu will pop up. Select different types of content blocks you can add. Type / followed by /bullet or /num to create a list. Click + New Page button at the bottom of your sidebar to add a new page. Click + next to any page title in the sidebar to quickly add a new subpage, Document , Grid , or Kanban Board . 
Keyboard shortcuts, markdown, and code block Keyboard shortcuts guide Markdown reference Type /code to insert a code block // This is the main function.\nfn main() {\n // Print text to the console.\n println!(\"Hello World!\");\n} This is a paragraph This is a paragraph Have a question❓ Click ? at the bottom right for help and support. This is a paragraph This is a paragraph Click ? at the bottom right for help and support. Like AppFlowy? Follow us: GitHub Twitter : @appflowy Newsletter "; assert_eq!(&text, expected); } diff --git a/services/appflowy-worker/src/import_worker/worker.rs b/services/appflowy-worker/src/import_worker/worker.rs index 246df11e..2e0416c3 100644 --- a/services/appflowy-worker/src/import_worker/worker.rs +++ b/services/appflowy-worker/src/import_worker/worker.rs @@ -1003,7 +1003,7 @@ async fn process_unzip_file( Ok(bytes) => { if let Err(err) = redis_client .set_ex::, Value>( - encode_collab_key(&w_database_id.to_string()), + encode_collab_key(w_database_id.to_string()), bytes, 2592000, // WorkspaceDatabase => 1 month ) @@ -1186,7 +1186,7 @@ async fn process_unzip_file( }); if result.is_err() { - let _: RedisResult = redis_client.del(encode_collab_key(&w_database_id)).await; + let _: RedisResult = redis_client.del(encode_collab_key(w_database_id)).await; let _: RedisResult = redis_client .del(encode_collab_key(&import_task.workspace_id)) .await; diff --git a/services/appflowy-worker/src/indexer_worker/worker.rs b/services/appflowy-worker/src/indexer_worker/worker.rs index 34a66258..f25c6f8b 100644 --- a/services/appflowy-worker/src/indexer_worker/worker.rs +++ b/services/appflowy-worker/src/indexer_worker/worker.rs @@ -1,6 +1,6 @@ use app_error::AppError; -use database::index::get_collabs_indexed_at; -use indexer::collab_indexer::{Indexer, IndexerProvider}; +use database::index::{get_collab_embedding_fragment_ids, get_collabs_indexed_at}; +use indexer::collab_indexer::IndexerProvider; use indexer::entity::EmbeddingRecord; use indexer::error::IndexerError; use indexer::metrics::EmbeddingMetrics; @@ -12,7 +12,6 @@ use indexer::scheduler::{spawn_pg_write_embeddings, UnindexedCollabTask, Unindex use indexer::thread_pool::ThreadPoolNoAbort; use indexer::vector::embedder::Embedder; use indexer::vector::open_ai; -use rayon::prelude::*; use redis::aio::ConnectionManager; use secrecy::{ExposeSecret, Secret}; use sqlx::PgPool; @@ -20,8 +19,9 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; use tokio::sync::RwLock; +use tokio::task::JoinSet; use tokio::time::{interval, MissedTickBehavior}; -use tracing::{error, info, trace}; +use tracing::{error, info, trace, warn}; pub struct BackgroundIndexerConfig { pub enable: bool, @@ -134,7 +134,7 @@ async fn process_upcoming_tasks( let collab_ids: Vec<_> = tasks.iter().map(|task| task.object_id).collect(); - let indexed_collabs = get_collabs_indexed_at(&pg_pool, collab_ids) + let indexed_collabs = get_collabs_indexed_at(&pg_pool, collab_ids.clone()) .await .unwrap_or_default(); @@ -154,36 +154,78 @@ async fn process_upcoming_tasks( let start = Instant::now(); let num_tasks = tasks.len(); - tasks.into_par_iter().for_each(|task| { - let result = threads.install(|| { - if let Some(indexer) = indexer_provider.indexer_for(task.collab_type) { - let embedder = create_embedder(&config); - let result = handle_task(embedder, indexer, task); - match result { - None => metrics.record_failed_embed_count(1), - Some(record) => { - metrics.record_embed_count(1); - trace!( - 
"[Background Embedding] send {} embedding record to write task", - record.object_id - ); - if let Err(err) = sender.send(record) { - trace!( - "[Background Embedding] failed to send embedding record to write task: {:?}", - err - ); - } - }, + let existing_embeddings = get_collab_embedding_fragment_ids(&pg_pool, collab_ids) + .await + .unwrap_or_default(); + let mut join_set = JoinSet::new(); + for task in tasks { + if let Some(indexer) = indexer_provider.indexer_for(task.collab_type) { + let embedder = create_embedder(&config); + trace!( + "[Background Embedding] processing task: {}, content:{:?}, collab_type: {}", + task.object_id, + task.data, + task.collab_type + ); + let paragraphs = match task.data { + UnindexedData::Paragraphs(paragraphs) => paragraphs, + UnindexedData::Text(text) => text.split('\n').map(|s| s.to_string()).collect(), + }; + let mut chunks = match indexer.create_embedded_chunks_from_text( + task.object_id, + paragraphs, + embedder.model(), + ) { + Ok(chunks) => chunks, + Err(err) => { + warn!( + "[Background Embedding] failed to create embedded chunks for task: {}, error: {:?}", + task.object_id, + err + ); + continue; + }, + }; + if let Some(existing_chunks) = existing_embeddings.get(&task.object_id) { + for chunk in chunks.iter_mut() { + if existing_chunks.contains(&chunk.fragment_id) { + chunk.content = None; // Clear content to mark unchanged chunk + chunk.embedding = None; + } } } - }); - if let Err(err) = result { - error!( - "[Background Embedding] Failed to process embedder task: {:?}", - err - ); + join_set.spawn(async move { + let embeddings = indexer.embed(&embedder, chunks).await.ok()?; + embeddings.map(|embeddings| EmbeddingRecord { + workspace_id: task.workspace_id, + object_id: task.object_id, + collab_type: task.collab_type, + tokens_used: embeddings.tokens_consumed, + contents: embeddings.params, + }) + }); } - }); + } + + while let Some(Ok(result)) = join_set.join_next().await { + match result { + None => metrics.record_failed_embed_count(1), + Some(record) => { + metrics.record_embed_count(1); + trace!( + "[Background Embedding] send {} embedding record to write task", + record.object_id + ); + if let Err(err) = sender.send(record) { + trace!( + "[Background Embedding] failed to send embedding record to write task: {:?}", + err + ); + } + }, + } + } + let cost = start.elapsed().as_millis(); metrics.record_gen_embedding_time(num_tasks as u32, cost); } @@ -212,32 +254,6 @@ async fn process_upcoming_tasks( } } -fn handle_task( - embedder: Embedder, - indexer: Arc, - task: UnindexedCollabTask, -) -> Option { - trace!( - "[Background Embedding] processing task: {}, content:{:?}, collab_type: {}", - task.object_id, - task.data, - task.collab_type - ); - let chunks = match task.data { - UnindexedData::Text(text) => indexer - .create_embedded_chunks_from_text(task.object_id.clone(), text, embedder.model()) - .ok()?, - }; - let embeddings = indexer.embed(&embedder, chunks).ok()?; - embeddings.map(|embeddings| EmbeddingRecord { - workspace_id: task.workspace_id, - object_id: task.object_id, - collab_type: task.collab_type, - tokens_used: embeddings.tokens_consumed, - contents: embeddings.params, - }) -} - fn create_embedder(config: &BackgroundIndexerConfig) -> Embedder { Embedder::OpenAI(open_ai::Embedder::new( config.open_api_key.expose_secret().clone(), diff --git a/src/api/workspace.rs b/src/api/workspace.rs index e49b2a11..8bcb043f 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -856,12 +856,12 @@ async fn create_collab_handler( 
.can_index_workspace(&workspace_id) .await? { - if let Ok(text) = Document::open(collab).and_then(|doc| doc.to_plain_text(false, true)) { + if let Ok(paragraphs) = Document::open(collab).map(|doc| doc.paragraphs()) { let pending = UnindexedCollabTask::new( workspace_id, params.object_id, params.collab_type, - UnindexedData::Text(text), + UnindexedData::Paragraphs(paragraphs), ); state .indexer_scheduler @@ -958,8 +958,7 @@ async fn batch_create_collab_handler( Ok(_) => { match params.collab_type { CollabType::Document => { - let index_text = - Document::open(collab).and_then(|doc| doc.to_plain_text(false, true)); + let index_text = Document::open(collab).map(|doc| doc.paragraphs()); Some((Some(index_text), params)) }, _ => { @@ -1010,12 +1009,12 @@ async fn batch_create_collab_handler( .flat_map(|value| match std::mem::take(&mut value.0) { None => None, Some(text) => text - .map(|text| { + .map(|paragraphs| { UnindexedCollabTask::new( workspace_id, value.1.object_id, value.1.collab_type, - UnindexedData::Text(text), + UnindexedData::Paragraphs(paragraphs), ) }) .ok(), @@ -1826,16 +1825,18 @@ async fn update_collab_handler( )) })?; - if let Ok(text) = Document::open(collab).and_then(|doc| doc.to_plain_text(false, true)) { - let pending = UnindexedCollabTask::new( - workspace_id, - params.object_id, - params.collab_type, - UnindexedData::Text(text), - ); - state - .indexer_scheduler - .index_pending_collab_one(pending, true)?; + if let Ok(paragraphs) = Document::open(collab).map(|doc| doc.paragraphs()) { + if !paragraphs.is_empty() { + let pending = UnindexedCollabTask::new( + workspace_id, + params.object_id, + params.collab_type, + UnindexedData::Paragraphs(paragraphs), + ); + state + .indexer_scheduler + .index_pending_collab_one(pending, true)?; + } } }, _ => { diff --git a/src/biz/collab/ops.rs b/src/biz/collab/ops.rs index d8dcb99d..d49d9e9c 100644 --- a/src/biz/collab/ops.rs +++ b/src/biz/collab/ops.rs @@ -1007,12 +1007,7 @@ fn fill_in_db_row_doc( })?; let doc = Document::open(doc_collab) .map_err(|err| AppError::Internal(anyhow::anyhow!("Failed to open document: {:?}", err)))?; - let plain_text = doc.to_plain_text(true, false).map_err(|err| { - AppError::Internal(anyhow::anyhow!( - "Failed to convert document to plain text: {:?}", - err - )) - })?; + let plain_text = doc.paragraphs().join(""); row_detail.doc = Some(plain_text); Ok(()) } diff --git a/src/biz/collab/utils.rs b/src/biz/collab/utils.rs index b7978657..240bf0e5 100644 --- a/src/biz/collab/utils.rs +++ b/src/biz/collab/utils.rs @@ -609,9 +609,7 @@ pub async fn get_database_row_doc_changes( .map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to create document: {:?}", e)))?; // if the document content is the same, there is no need to update - if cur_doc.to_plain_text(false, false).unwrap_or_default() - == new_doc.to_plain_text(false, false).unwrap_or_default() - { + if cur_doc.paragraphs() == new_doc.paragraphs() { return Ok(None); }; diff --git a/tests/ai_test/chat_with_selected_doc_test.rs b/tests/ai_test/chat_with_selected_doc_test.rs index b9cf2293..9e625cdc 100644 --- a/tests/ai_test/chat_with_selected_doc_test.rs +++ b/tests/ai_test/chat_with_selected_doc_test.rs @@ -257,7 +257,7 @@ Overall, Alex balances his work as a software programmer with his passion for sp // Simulate insert new content let contents = alex_banker_story(); editor.insert_paragraphs(contents.into_iter().map(|s| s.to_string()).collect()); - let text = editor.document.to_plain_text(false, false).unwrap(); + let text = 
editor.document.paragraphs().join(""); let expected = alex_banker_story().join(""); assert_eq!(text, expected); diff --git a/tests/collab/collab_embedding_test.rs b/tests/collab/collab_embedding_test.rs index f9607eac..6833c32f 100644 --- a/tests/collab/collab_embedding_test.rs +++ b/tests/collab/collab_embedding_test.rs @@ -90,13 +90,13 @@ async fn document_full_sync_then_search_test() { let remote_document = test_client .create_document_collab(workspace_id, object_id) .await; - let remote_plain_text = remote_document.to_plain_text(false, false).unwrap(); - let local_plain_text = local_document.document.to_plain_text(false, false).unwrap(); + let remote_plain_text = remote_document.paragraphs().join(""); + let local_plain_text = local_document.document.paragraphs().join(""); assert_eq!(local_plain_text, remote_plain_text); let search_result = test_client .wait_unit_get_search_result(&workspace_id, "workflows", 1) .await; assert_eq!(search_result.len(), 1); - assert_eq!(search_result[0].preview, Some("AppFlowy is an open-source project. It is an alternative to tools like Notion. AppFlowy provides full control of your data. The project is built using Flutter for the frontend. Rust powers AppFlowy's ".to_string())); + assert_eq!(search_result[0].preview, Some("AppFlowy is an open-source project.It is an alternative to tools like Notion.AppFlowy provides full control of your data.The project is built using Flutter for the frontend.Rust powers AppFlowy's back".to_string())); } diff --git a/tests/collab/database_crud.rs b/tests/collab/database_crud.rs index 80966fcf..a243ec94 100644 --- a/tests/collab/database_crud.rs +++ b/tests/collab/database_crud.rs @@ -35,7 +35,7 @@ async fn database_row_upsert_with_doc() { assert!(row_detail.has_doc); assert_eq!( row_detail.doc, - Some(String::from("\nThis is a document of a database row")) + Some(String::from("This is a document of a database row")) ); } // Upsert row with another doc @@ -57,7 +57,7 @@ async fn database_row_upsert_with_doc() { .unwrap()[0]; assert_eq!( row_detail.doc, - Some(String::from("\nThis is a another document")) + Some(String::from("This is a another document")) ); } } @@ -135,7 +135,7 @@ async fn database_row_upsert() { assert!(row_detail.has_doc); assert_eq!( row_detail.doc, - Some("\nThis is a document of a database row".to_string()) + Some("This is a document of a database row".to_string()) ); } } @@ -327,6 +327,6 @@ async fn database_insert_row_with_doc() { assert!(row_detail.has_doc); assert_eq!( row_detail.doc, - Some("\nThis is a document of a database row".to_string()) + Some("This is a document of a database row".to_string()) ); }
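
Illustrative note (not part of the patch itself): the indexing change above splits a document into paragraphs, groups them into chunks of at most ~8000 characters, and derives a content-addressed fragment id for each chunk with xxhash64, so a chunk whose id already exists in af_collab_embeddings can have its content and embedding cleared and be skipped on re-embedding. A minimal standalone Rust sketch of that grouping-and-hashing step, assuming twox-hash 2.1 with the "xxhash64" feature as declared in libs/indexer/Cargo.toml, might look like this:

// Sketch only: mirrors group_paragraphs_by_max_content_len and the consistent-hash
// fragment ids used by DocumentIndexer in this patch; names and the example input
// are illustrative, not an exact copy of the production code.
use twox_hash::xxhash64::Hasher;

fn group_paragraphs_by_max_content_len(paragraphs: Vec<String>, max_content_len: usize) -> Vec<String> {
    let mut result = Vec::new();
    let mut current = String::new();
    for paragraph in paragraphs {
        if paragraph.len() + current.len() > max_content_len {
            // Adding this paragraph would exceed the limit: flush the accumulated
            // chunk and start a new one with the current paragraph.
            let accumulated = std::mem::replace(&mut current, paragraph);
            if !accumulated.is_empty() {
                result.push(accumulated);
            }
        } else {
            current.push_str(&paragraph);
        }
    }
    if !current.is_empty() {
        result.push(current);
    }
    result
}

fn main() {
    let paragraphs = vec![
        "AppFlowy is an open-source project.".to_string(),
        "It is an alternative to tools like Notion.".to_string(),
    ];
    for chunk in group_paragraphs_by_max_content_len(paragraphs, 8000) {
        // Content-addressed fragment id: identical chunk content always hashes to the
        // same id, which is how unchanged fragments are detected and left un-embedded.
        let fragment_id = format!("{:x}", Hasher::oneshot(0, chunk.as_bytes()));
        println!("{fragment_id}: {chunk}");
    }
}

In the patch this id doubles as part of the (fragment_id, oid) primary key, and the af_collab_embeddings_upsert procedure deletes fragments missing from the new list while leaving matching ones untouched, so only genuinely changed chunks incur new OpenAI embedding calls.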