Mirror of https://github.com/AppFlowy-IO/AppFlowy-Cloud.git, synced 2025-04-19 03:24:42 -04:00
feat: Workspace token usage (#584)
* feat: register OpenAI token usage during indexing
* feat: register OpenAI token usage during search
* chore: fix OpenAI token usage when searching for documents
parent a6bcbd583f
commit c4702bbbdf

7 changed files with 94 additions and 25 deletions
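Review note: both code paths follow the same accounting pattern: read usage.total_tokens off the OpenAI response (falling back to 0 when no usage is reported) and atomically add it to a per-workspace BIGINT counter. A minimal sketch of that pattern, not part of the diff (the standalone record_index_tokens helper and the pool-based execution are illustrative; the diff threads a transaction through instead):

    use sqlx::PgPool;
    use uuid::Uuid;

    // Sketch of the accounting pattern this commit applies in both the
    // indexing and search paths: add the reported token count to a
    // per-workspace counter in a single atomic UPDATE.
    async fn record_index_tokens(
      pool: &PgPool,      // illustrative; the diff uses a Transaction
      workspace_id: &Uuid,
      tokens_used: u32,   // resp.usage.total_tokens, or 0 when absent
    ) -> Result<(), sqlx::Error> {
      if tokens_used > 0 {
        sqlx::query(
          "UPDATE af_workspace SET index_token_usage = index_token_usage + $2 WHERE workspace_id = $1",
        )
        .bind(workspace_id)
        .bind(tokens_used as i64)
        .execute(pool)
        .await?;
      }
      Ok(())
    }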
@@ -2,6 +2,7 @@ use std::ops::DerefMut;
 use pgvector::Vector;
 use sqlx::Transaction;
+use uuid::Uuid;
 
 use database_entity::dto::AFCollabEmbeddingParams;
 
@@ -20,8 +21,20 @@ pub async fn has_collab_embeddings(
 
 pub async fn upsert_collab_embeddings(
   tx: &mut Transaction<'_, sqlx::Postgres>,
+  workspace_id: &Uuid,
+  tokens_used: u32,
   records: Vec<AFCollabEmbeddingParams>,
 ) -> Result<(), sqlx::Error> {
+  if tokens_used > 0 {
+    sqlx::query(
+      "UPDATE af_workspace SET index_token_usage = index_token_usage + $2 WHERE workspace_id = $1",
+    )
+    .bind(workspace_id)
+    .bind(tokens_used as i64)
+    .execute(tx.deref_mut())
+    .await?;
+  }
+
   for r in records {
     sqlx::query(
       r#"INSERT INTO af_collab_embeddings (fragment_id, oid, partition_key, content_type, content, embedding, indexed_at)
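Review note: keeping the counter UPDATE and the embedding upserts on the same transaction means the usage increment only becomes visible if the embeddings themselves commit. A hypothetical caller, for illustration only:

    // Hypothetical caller (not from this diff): counter and embeddings
    // commit or roll back together.
    let mut tx = pool.begin().await?;
    upsert_collab_embeddings(&mut tx, &workspace_id, tokens_used, records).await?;
    tx.commit().await?;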
@@ -8,13 +8,21 @@ use uuid::Uuid;
 pub async fn search_documents(
   tx: &mut Transaction<'_, sqlx::Postgres>,
   params: SearchDocumentParams,
+  tokens_used: u32,
 ) -> Result<Vec<SearchDocumentItem>, sqlx::Error> {
   let query = sqlx::query_as::<_, SearchDocumentItem>(
     r#"
+  WITH workspace AS (
+    UPDATE af_workspace
+    SET search_token_usage = search_token_usage + $6
+    WHERE workspace_id = $2
+    RETURNING workspace_id
+  )
   SELECT
     em.oid AS object_id,
     collab.workspace_id,
     em.partition_key AS collab_type,
     em.content_type,
     LEFT(em.content, $4) AS content_preview,
     u.name AS created_by,
     collab.created_at AS created_at,
@@ -32,7 +40,8 @@ pub async fn search_documents(
   .bind(params.workspace_id)
   .bind(Vector::from(params.embedding))
   .bind(params.preview)
-  .bind(params.limit);
+  .bind(params.limit)
+  .bind(tokens_used as i64);
   let rows = query.fetch_all(tx.deref_mut()).await?;
   Ok(rows)
 }
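Review note: the WITH workspace AS (UPDATE ...) prefix is a data-modifying CTE, so PostgreSQL bumps search_token_usage and runs the similarity search in one statement and one round trip; the $2 and $6 binds are shared between the two parts. A self-contained demo of the same shape, with hypothetical counters/items tables that are not from this repo:

    use sqlx::PgPool;

    // Data-modifying CTE demo (hypothetical tables): the UPDATE executes as
    // part of the same statement as the SELECT.
    async fn count_and_fetch(pool: &PgPool, key: i64) -> Result<Vec<(i64, String)>, sqlx::Error> {
      sqlx::query_as(
        r#"
        WITH bump AS (
          UPDATE counters SET hits = hits + 1 WHERE id = $1 RETURNING id
        )
        SELECT i.id, i.name FROM items i WHERE i.owner = $1
        "#,
      )
      .bind(key)
      .fetch_all(pool)
      .await
    }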
migrations/20240529054858_workspace_add_token_usage.sql (new file, 3 lines)

@@ -0,0 +1,3 @@
+-- Add migration script here
+ALTER TABLE af_workspace ADD COLUMN search_token_usage BIGINT NOT NULL DEFAULT 0;
+ALTER TABLE af_workspace ADD COLUMN index_token_usage BIGINT NOT NULL DEFAULT 0;
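Review note: with both counters defaulting to 0, per-workspace totals can be read back with a plain query at any time; a small illustrative helper, not part of the diff:

    use sqlx::{PgPool, Row};
    use uuid::Uuid;

    // Illustrative read-back of the two new counters.
    async fn token_usage(pool: &PgPool, workspace_id: &Uuid) -> Result<(i64, i64), sqlx::Error> {
      let row = sqlx::query(
        "SELECT index_token_usage, search_token_usage FROM af_workspace WHERE workspace_id = $1",
      )
      .bind(workspace_id)
      .fetch_one(pool)
      .await?;
      Ok((row.get(0), row.get(1)))
    }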
@@ -16,6 +16,7 @@ use tokio::task::JoinSet;
 use tokio::time::interval;
 use tokio_util::sync::CancellationToken;
 use tracing::instrument;
+use uuid::Uuid;
 
 use collab_stream::client::CollabRedisStream;
 use collab_stream::model::{CollabUpdateEvent, StreamMessage};

@@ -75,6 +76,8 @@ impl CollabHandle {
     if !messages.is_empty() {
       Self::handle_collab_updates(&mut update_stream, content.get_collab(), messages).await?;
     }
+    let workspace_id =
+      Uuid::parse_str(&workspace_id).map_err(|e| crate::error::Error::InvalidWorkspace(e))?;
 
     let mut tasks = JoinSet::new();
     tasks.spawn(Self::receive_collab_updates(

@@ -108,7 +111,7 @@ impl CollabHandle {
     mut update_stream: StreamGroup,
     content: Weak<dyn Indexable>,
     object_id: String,
-    workspace_id: String,
+    workspace_id: Uuid,
     ingest_interval: Duration,
     closing: CancellationToken,
   ) {

@@ -175,7 +178,7 @@ impl CollabHandle {
     mut updates: Pin<Box<dyn Stream<Item = FragmentUpdate> + Send + Sync>>,
     indexer: Arc<dyn Indexer>,
     object_id: String,
-    workspace_id: String,
+    workspace_id: Uuid,
     ingest_interval: Duration,
     token: CancellationToken,
   ) {

@@ -186,14 +189,14 @@ impl CollabHandle {
     loop {
       select! {
         _ = interval.tick() => {
-          match Self::publish_updates(&indexer, &mut inserts, &mut removals).await {
+          match Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await {
             Ok(_) => last_update = Instant::now(),
             Err(err) => tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err),
           }
         }
         _ = token.cancelled() => {
           tracing::trace!("document {}/{} watcher closing signal received, flushing remaining updates", workspace_id, object_id);
-          if let Err(err) = Self::publish_updates(&indexer, &mut inserts, &mut removals).await {
+          if let Err(err) = Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await {
             tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err);
           }
           return;

@@ -215,7 +218,7 @@ impl CollabHandle {
 
           let now = Instant::now();
           if now.duration_since(last_update) > ingest_interval {
-            match Self::publish_updates(&indexer, &mut inserts, &mut removals).await {
+            match Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await {
               Ok(_) => last_update = now,
               Err(err) => tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err),
             }

@@ -227,6 +230,7 @@ impl CollabHandle {
 
   async fn publish_updates(
     indexer: &Arc<dyn Indexer>,
+    workspace_id: &Uuid,
     inserts: &mut HashMap<FragmentID, Fragment>,
     removals: &mut HashSet<FragmentID>,
   ) -> Result<()> {

@@ -236,7 +240,7 @@ impl CollabHandle {
     let inserts: Vec<_> = inserts.drain().map(|(_, doc)| doc).collect();
     if !inserts.is_empty() {
       tracing::info!("updating indexes for {} fragments", inserts.len());
-      indexer.update_index(inserts).await?;
+      indexer.update_index(workspace_id, inserts).await?;
     }
 
     if !removals.is_empty() {

@@ -346,5 +350,14 @@ mod test {
       .collect::<HashSet<_>>();
 
     assert_eq!(contents.len(), 1);
+
+    let tokens: i64 =
+      sqlx::query("SELECT index_token_usage from af_workspace WHERE workspace_id = $1")
+        .bind(&workspace_id)
+        .fetch_one(&db)
+        .await
+        .unwrap()
+        .get(0);
+    assert_ne!(tokens, 0);
   }
 }
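Review note: the hunks above thread workspace_id (now a parsed Uuid) through a watcher loop with a common debounce shape: publish on interval ticks, publish again once the ingest interval has elapsed since the last update, and flush one final time when the cancellation token fires. A reduced sketch of that control flow, with the indexer and buffers elided (illustrative, not the real CollabHandle code):

    use std::time::Duration;
    use tokio::{select, time::interval};
    use tokio_util::sync::CancellationToken;

    // Reduced watcher-loop shape: periodic publishes plus a final flush on
    // shutdown.
    async fn watch(token: CancellationToken, ingest_interval: Duration) {
      let mut ticker = interval(ingest_interval);
      loop {
        select! {
          _ = ticker.tick() => {
            // publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await
          }
          _ = token.cancelled() => {
            // flush remaining updates once, then exit
            return;
          }
        }
      }
    }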
@@ -14,6 +14,8 @@ pub enum Error {
   Sql(#[from] sqlx::Error),
   #[error("OpenAI failed to process request: {0}")]
   OpenAI(String),
+  #[error("invalid workspace ID: {0}")]
+  InvalidWorkspace(uuid::Error),
 }
 
 pub type Result<T, E = Error> = std::result::Result<T, E>;
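Review note: unlike Sql, the new InvalidWorkspace variant has no #[from], so call sites convert explicitly, as the Uuid::parse_str(...).map_err(...) hunk above does. If implicit conversion were preferred, #[from] would let the ? operator do the wrapping; a hypothetical alternative, not what the diff does:

    // Hypothetical alternative: derive From so `?` converts uuid errors.
    #[derive(Debug, thiserror::Error)]
    pub enum Error {
      #[error("invalid workspace ID: {0}")]
      InvalidWorkspace(#[from] uuid::Error),
    }

    fn parse(workspace_id: &str) -> Result<uuid::Uuid, Error> {
      Ok(uuid::Uuid::parse_str(workspace_id)?)
    }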
@@ -7,6 +7,7 @@ use openai_dive::v1::resources::embedding::{
 };
 use serde::{Deserialize, Serialize};
 use sqlx::PgPool;
+use uuid::Uuid;
 
 use database::index::{has_collab_embeddings, remove_collab_embeddings, upsert_collab_embeddings};
 use database_entity::dto::{AFCollabEmbeddingParams, EmbeddingContentType};

@@ -17,7 +18,7 @@ use crate::error::Result;
 pub trait Indexer: Send + Sync {
   /// Check if a document with the given id already has a corresponding index entry.
   async fn was_indexed(&self, object_id: &str) -> Result<bool>;
-  async fn update_index(&self, documents: Vec<Fragment>) -> Result<()>;
+  async fn update_index(&self, workspace_id: &Uuid, documents: Vec<Fragment>) -> Result<()>;
   async fn remove(&self, ids: &[FragmentID]) -> Result<()>;
 }

@@ -99,7 +100,7 @@ impl PostgresIndexer {
     Self { openai, db }
   }
 
-  async fn get_embeddings(&self, fragments: Vec<Fragment>) -> Result<Vec<EmbedFragment>> {
+  async fn get_embeddings(&self, fragments: Vec<Fragment>) -> Result<Embeddings> {
     let inputs: Vec<_> = fragments
       .iter()
       .map(|fragment| fragment.content.clone())

@@ -118,10 +119,12 @@ impl PostgresIndexer {
       .map_err(|e| crate::error::Error::OpenAI(e.to_string()))?;
 
     tracing::trace!("fetched {} embeddings", resp.data.len());
-    if let Some(usage) = resp.usage {
-      tracing::info!("OpenAI API usage: {}", usage.total_tokens);
-      //TODO: report usage statistics
-    }
+    let tokens_used = if let Some(usage) = resp.usage {
+      tracing::info!("OpenAI API index tokens used: {}", usage.total_tokens);
+      usage.total_tokens
+    } else {
+      0
+    };
 
     let mut fragments: Vec<_> = fragments.into_iter().map(EmbedFragment::from).collect();
     for e in resp.data.into_iter() {

@@ -135,18 +138,27 @@ impl PostgresIndexer {
       };
       fragments[e.index as usize].embedding = Some(embedding);
     }
-    Ok(fragments)
+    Ok(Embeddings {
+      tokens_used,
+      fragments,
+    })
   }
 
-  async fn store_embeddings(&self, fragments: Vec<EmbedFragment>) -> Result<()> {
+  async fn store_embeddings(&self, workspace_id: &Uuid, embeddings: Embeddings) -> Result<()> {
     tracing::trace!(
       "storing {} embeddings inside of vector database",
-      fragments.len()
+      embeddings.fragments.len()
     );
     let mut tx = self.db.begin().await?;
     upsert_collab_embeddings(
       &mut tx,
-      fragments.into_iter().map(EmbedFragment::into).collect(),
+      workspace_id,
+      embeddings.tokens_used,
+      embeddings
+        .fragments
+        .into_iter()
+        .map(EmbedFragment::into)
+        .collect(),
     )
     .await?;
     tx.commit().await?;

@@ -154,6 +166,11 @@ impl PostgresIndexer {
     }
   }
 
+struct Embeddings {
+  tokens_used: u32,
+  fragments: Vec<EmbedFragment>,
+}
+
 #[async_trait]
 impl Indexer for PostgresIndexer {
   async fn was_indexed(&self, object_id: &str) -> Result<bool> {

@@ -161,9 +178,9 @@ impl Indexer for PostgresIndexer {
     Ok(found)
   }
 
-  async fn update_index(&self, documents: Vec<Fragment>) -> Result<()> {
+  async fn update_index(&self, workspace_id: &Uuid, documents: Vec<Fragment>) -> Result<()> {
     let embeddings = self.get_embeddings(documents).await?;
-    self.store_embeddings(embeddings).await?;
+    self.store_embeddings(workspace_id, embeddings).await?;
     Ok(())
   }

@@ -191,7 +208,7 @@ mod test {
     let db = db_pool().await;
     let object_id = uuid::Uuid::new_v4();
     let uid = rand::random();
-    setup_collab(&db, uid, object_id, vec![]).await;
+    let workspace_id = setup_collab(&db, uid, object_id, vec![]).await;
 
     let openai = openai_client();

@@ -209,10 +226,13 @@ mod test {
 
     // resolve embeddings from OpenAI
     let embeddings = indexer.get_embeddings(fragments).await.unwrap();
-    assert_eq!(embeddings[0].embedding.is_some(), true);
+    assert_eq!(embeddings.fragments[0].embedding.is_some(), true);
 
     // store embeddings in DB
-    indexer.store_embeddings(embeddings).await.unwrap();
+    indexer
+      .store_embeddings(&workspace_id, embeddings)
+      .await
+      .unwrap();
 
     // search for embedding
     let mut tx = indexer.db.begin().await.unwrap();
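Review note: update_index now takes the owning workspace, so any other Indexer implementation has to adopt the same signature. A minimal stub showing the post-change trait surface (NoopIndexer is hypothetical; Fragment, FragmentID and Result are the crate's own types):

    use async_trait::async_trait;
    use uuid::Uuid;

    struct NoopIndexer;

    #[async_trait]
    impl Indexer for NoopIndexer {
      async fn was_indexed(&self, _object_id: &str) -> Result<bool> {
        Ok(false)
      }
      async fn update_index(&self, _workspace_id: &Uuid, _documents: Vec<Fragment>) -> Result<()> {
        Ok(()) // a real implementation would embed and store the documents
      }
      async fn remove(&self, _ids: &[FragmentID]) -> Result<()> {
        Ok(())
      }
    }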
@@ -30,9 +30,16 @@ pub async fn search_document(
     .await
     .map_err(|e| AppResponseError::new(ErrorCode::Internal, e.to_string()))?;
 
-  if let Some(usage) = embeddings.usage {
-    tracing::info!("OpenAI API usage: {}", usage.total_tokens)
-  }
+  let tokens_used = if let Some(usage) = embeddings.usage {
+    tracing::info!(
+      "workspace {} OpenAI API search tokens used: {}",
+      workspace_id,
+      usage.total_tokens
+    );
+    usage.total_tokens
+  } else {
+    0
+  };
 
   let embedding = embeddings
     .data

@@ -61,8 +68,10 @@ pub async fn search_document(
       preview: request.preview_size.unwrap_or(180) as i32,
       embedding,
     },
+    tokens_used,
   )
   .await?;
+  tx.commit().await?;
   Ok(
     results
       .into_iter()
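Review note: the added tx.commit().await?; is load-bearing. Because search_documents now folds a counter UPDATE into its CTE, the search transaction is no longer read-only, and dropping it uncommitted would roll the search_token_usage increment back. Reduced caller shape, for illustration:

    // Reduced caller shape (illustrative): commit, or the usage increment
    // is rolled back when the transaction drops.
    let mut tx = pg_pool.begin().await?;
    let results = search_documents(&mut tx, params, tokens_used).await?;
    tx.commit().await?;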