feat: Workspace token usage (#584)

* feat: register OpenAI token usage during indexing

* feat: register OpenAI token usage during search

* chore: fix OpenAI token usage when searching for documents
Bartosz Sypytkowski 2024-05-29 10:07:56 +02:00 committed by GitHub
parent a6bcbd583f
commit c4702bbbdf
7 changed files with 94 additions and 25 deletions
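At a glance, this change threads the token counts that OpenAI reports for embedding requests down to the workspace row: indexing accumulates into af_workspace.index_token_usage and search into af_workspace.search_token_usage. A condensed sketch of the indexing path, assuming the PostgresIndexer, Fragment and Embeddings types as they appear in the diff below; the wrapper function itself is illustrative only, the diff is authoritative:

// Illustrative condensation of PostgresIndexer::update_index further down.
async fn index_with_usage(
    indexer: &PostgresIndexer,
    workspace_id: &uuid::Uuid,
    documents: Vec<Fragment>,
) -> crate::error::Result<()> {
    // get_embeddings now returns the OpenAI token count together with the embedded
    // fragments instead of only logging it.
    let embeddings: Embeddings = indexer.get_embeddings(documents).await?;
    // store_embeddings opens a transaction, bumps af_workspace.index_token_usage by
    // embeddings.tokens_used and upserts the fragment embeddings in the same commit.
    indexer.store_embeddings(workspace_id, embeddings).await?;
    Ok(())
}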

@@ -2,6 +2,7 @@ use std::ops::DerefMut;
use pgvector::Vector;
use sqlx::Transaction;
use uuid::Uuid;
use database_entity::dto::AFCollabEmbeddingParams;
@@ -20,8 +21,20 @@ pub async fn has_collab_embeddings(
pub async fn upsert_collab_embeddings(
tx: &mut Transaction<'_, sqlx::Postgres>,
workspace_id: &Uuid,
tokens_used: u32,
records: Vec<AFCollabEmbeddingParams>,
) -> Result<(), sqlx::Error> {
if tokens_used > 0 {
sqlx::query(
"UPDATE af_workspace SET index_token_usage = index_token_usage + $2 WHERE workspace_id = $1",
)
.bind(workspace_id)
.bind(tokens_used as i64)
.execute(tx.deref_mut())
.await?;
}
for r in records {
sqlx::query(
r#"INSERT INTO af_collab_embeddings (fragment_id, oid, partition_key, content_type, content, embedding, indexed_at)

@@ -8,13 +8,21 @@ use uuid::Uuid;
pub async fn search_documents(
tx: &mut Transaction<'_, sqlx::Postgres>,
params: SearchDocumentParams,
tokens_used: u32,
) -> Result<Vec<SearchDocumentItem>, sqlx::Error> {
let query = sqlx::query_as::<_, SearchDocumentItem>(
r#"
WITH workspace AS (
UPDATE af_workspace
SET search_token_usage = search_token_usage + $6
WHERE workspace_id = $2
RETURNING workspace_id
)
SELECT
em.oid AS object_id,
collab.workspace_id,
em.partition_key AS collab_type,
em.content_type,
LEFT(em.content, $4) AS content_preview,
u.name AS created_by,
collab.created_at AS created_at,
@@ -32,7 +40,8 @@ pub async fn search_documents(
.bind(params.workspace_id)
.bind(Vector::from(params.embedding))
.bind(params.preview)
.bind(params.limit);
.bind(params.limit)
.bind(tokens_used as i64);
let rows = query.fetch_all(tx.deref_mut()).await?;
Ok(rows)
}
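The usage bookkeeping above is folded into the search query itself through a data-modifying CTE, so the counter increment and the vector search execute as one statement inside the caller's transaction; tokens_used is widened to i64 because the counter column is a signed BIGINT. A minimal, self-contained sketch of the same pattern against the column added by this change (the helper is illustrative, not part of the commit):

use std::ops::DerefMut;
use sqlx::{Postgres, Transaction};
use uuid::Uuid;

// Illustrative only: bump the per-workspace search counter and read the new total back
// in one statement, mirroring the WITH workspace AS (UPDATE ...) shape used above.
async fn bump_search_usage(
    tx: &mut Transaction<'_, Postgres>,
    workspace_id: &Uuid,
    tokens_used: u32,
) -> Result<i64, sqlx::Error> {
    let (total,): (i64,) = sqlx::query_as(
        r#"
        WITH workspace AS (
            UPDATE af_workspace
            SET search_token_usage = search_token_usage + $2
            WHERE workspace_id = $1
            RETURNING search_token_usage
        )
        SELECT search_token_usage FROM workspace
        "#,
    )
    .bind(workspace_id)
    .bind(tokens_used as i64)
    .fetch_one(tx.deref_mut())
    .await?;
    Ok(total)
}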

@@ -0,0 +1,3 @@
-- Add migration script here
ALTER TABLE af_workspace ADD COLUMN search_token_usage BIGINT NOT NULL DEFAULT 0;
ALTER TABLE af_workspace ADD COLUMN index_token_usage BIGINT NOT NULL DEFAULT 0;
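Both counters start at 0 and are only incremented by the queries in this commit, so reporting is a plain read. A minimal sketch of reading them back, assuming an existing sqlx::PgPool; the helper name is hypothetical:

use sqlx::PgPool;
use uuid::Uuid;

// Hypothetical helper: returns (index_token_usage, search_token_usage) for one
// workspace, touching only the columns defined by the migration above.
async fn workspace_token_usage(
    pg_pool: &PgPool,
    workspace_id: &Uuid,
) -> Result<(i64, i64), sqlx::Error> {
    sqlx::query_as(
        "SELECT index_token_usage, search_token_usage FROM af_workspace WHERE workspace_id = $1",
    )
    .bind(workspace_id)
    .fetch_one(pg_pool)
    .await
}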

@@ -16,6 +16,7 @@ use tokio::task::JoinSet;
use tokio::time::interval;
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use uuid::Uuid;
use collab_stream::client::CollabRedisStream;
use collab_stream::model::{CollabUpdateEvent, StreamMessage};
@@ -75,6 +76,8 @@ impl CollabHandle {
if !messages.is_empty() {
Self::handle_collab_updates(&mut update_stream, content.get_collab(), messages).await?;
}
let workspace_id =
Uuid::parse_str(&workspace_id).map_err(|e| crate::error::Error::InvalidWorkspace(e))?;
let mut tasks = JoinSet::new();
tasks.spawn(Self::receive_collab_updates(
@@ -108,7 +111,7 @@ impl CollabHandle {
mut update_stream: StreamGroup,
content: Weak<dyn Indexable>,
object_id: String,
workspace_id: String,
workspace_id: Uuid,
ingest_interval: Duration,
closing: CancellationToken,
) {
@@ -175,7 +178,7 @@ impl CollabHandle {
mut updates: Pin<Box<dyn Stream<Item = FragmentUpdate> + Send + Sync>>,
indexer: Arc<dyn Indexer>,
object_id: String,
workspace_id: String,
workspace_id: Uuid,
ingest_interval: Duration,
token: CancellationToken,
) {
@@ -186,14 +189,14 @@ impl CollabHandle {
loop {
select! {
_ = interval.tick() => {
match Self::publish_updates(&indexer, &mut inserts, &mut removals).await {
match Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await {
Ok(_) => last_update = Instant::now(),
Err(err) => tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err),
}
}
_ = token.cancelled() => {
tracing::trace!("document {}/{} watcher closing signal received, flushing remaining updates", workspace_id, object_id);
if let Err(err) = Self::publish_updates(&indexer, &mut inserts, &mut removals).await {
if let Err(err) = Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await {
tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err);
}
return;
@@ -215,7 +218,7 @@ impl CollabHandle {
let now = Instant::now();
if now.duration_since(last_update) > ingest_interval {
match Self::publish_updates(&indexer, &mut inserts, &mut removals).await {
match Self::publish_updates(&indexer, &workspace_id, &mut inserts, &mut removals).await {
Ok(_) => last_update = now,
Err(err) => tracing::error!("document {}/{} watcher failed to publish fragment updates: {}", workspace_id, object_id, err),
}
@@ -227,6 +230,7 @@ impl CollabHandle {
async fn publish_updates(
indexer: &Arc<dyn Indexer>,
workspace_id: &Uuid,
inserts: &mut HashMap<FragmentID, Fragment>,
removals: &mut HashSet<FragmentID>,
) -> Result<()> {
@@ -236,7 +240,7 @@ impl CollabHandle {
let inserts: Vec<_> = inserts.drain().map(|(_, doc)| doc).collect();
if !inserts.is_empty() {
tracing::info!("updating indexes for {} fragments", inserts.len());
indexer.update_index(inserts).await?;
indexer.update_index(workspace_id, inserts).await?;
}
if !removals.is_empty() {
@@ -346,5 +350,14 @@ mod test {
.collect::<HashSet<_>>();
assert_eq!(contents.len(), 1);
let tokens: i64 =
sqlx::query("SELECT index_token_usage from af_workspace WHERE workspace_id = $1")
.bind(&workspace_id)
.fetch_one(&db)
.await
.unwrap()
.get(0);
assert_ne!(tokens, 0);
}
}

@@ -14,6 +14,8 @@ pub enum Error {
Sql(#[from] sqlx::Error),
#[error("OpenAI failed to process request: {0}")]
OpenAI(String),
#[error("invalid workspace ID: {0}")]
InvalidWorkspace(uuid::Error),
}
pub type Result<T, E = Error> = std::result::Result<T, E>;

@@ -7,6 +7,7 @@ use openai_dive::v1::resources::embedding::{
};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;
use database::index::{has_collab_embeddings, remove_collab_embeddings, upsert_collab_embeddings};
use database_entity::dto::{AFCollabEmbeddingParams, EmbeddingContentType};
@@ -17,7 +18,7 @@ use crate::error::Result;
pub trait Indexer: Send + Sync {
/// Check if the document with the given id already has a corresponding index entry.
async fn was_indexed(&self, object_id: &str) -> Result<bool>;
async fn update_index(&self, documents: Vec<Fragment>) -> Result<()>;
async fn update_index(&self, workspace_id: &Uuid, documents: Vec<Fragment>) -> Result<()>;
async fn remove(&self, ids: &[FragmentID]) -> Result<()>;
}
@@ -99,7 +100,7 @@ impl PostgresIndexer {
Self { openai, db }
}
async fn get_embeddings(&self, fragments: Vec<Fragment>) -> Result<Vec<EmbedFragment>> {
async fn get_embeddings(&self, fragments: Vec<Fragment>) -> Result<Embeddings> {
let inputs: Vec<_> = fragments
.iter()
.map(|fragment| fragment.content.clone())
@@ -118,10 +119,12 @@ impl PostgresIndexer {
.map_err(|e| crate::error::Error::OpenAI(e.to_string()))?;
tracing::trace!("fetched {} embeddings", resp.data.len());
if let Some(usage) = resp.usage {
tracing::info!("OpenAI API usage: {}", usage.total_tokens);
//TODO: report usage statistics
}
let tokens_used = if let Some(usage) = resp.usage {
tracing::info!("OpenAI API index tokens used: {}", usage.total_tokens);
usage.total_tokens
} else {
0
};
let mut fragments: Vec<_> = fragments.into_iter().map(EmbedFragment::from).collect();
for e in resp.data.into_iter() {
@@ -135,18 +138,27 @@ impl PostgresIndexer {
};
fragments[e.index as usize].embedding = Some(embedding);
}
Ok(fragments)
Ok(Embeddings {
tokens_used,
fragments,
})
}
async fn store_embeddings(&self, fragments: Vec<EmbedFragment>) -> Result<()> {
async fn store_embeddings(&self, workspace_id: &Uuid, embeddings: Embeddings) -> Result<()> {
tracing::trace!(
"storing {} embeddings inside of vector database",
fragments.len()
embeddings.fragments.len()
);
let mut tx = self.db.begin().await?;
upsert_collab_embeddings(
&mut tx,
fragments.into_iter().map(EmbedFragment::into).collect(),
workspace_id,
embeddings.tokens_used,
embeddings
.fragments
.into_iter()
.map(EmbedFragment::into)
.collect(),
)
.await?;
tx.commit().await?;
@@ -154,6 +166,11 @@ impl PostgresIndexer {
}
}
struct Embeddings {
tokens_used: u32,
fragments: Vec<EmbedFragment>,
}
#[async_trait]
impl Indexer for PostgresIndexer {
async fn was_indexed(&self, object_id: &str) -> Result<bool> {
@@ -161,9 +178,9 @@ impl Indexer for PostgresIndexer {
Ok(found)
}
async fn update_index(&self, documents: Vec<Fragment>) -> Result<()> {
async fn update_index(&self, workspace_id: &Uuid, documents: Vec<Fragment>) -> Result<()> {
let embeddings = self.get_embeddings(documents).await?;
self.store_embeddings(embeddings).await?;
self.store_embeddings(workspace_id, embeddings).await?;
Ok(())
}
@@ -191,7 +208,7 @@ mod test {
let db = db_pool().await;
let object_id = uuid::Uuid::new_v4();
let uid = rand::random();
setup_collab(&db, uid, object_id, vec![]).await;
let workspace_id = setup_collab(&db, uid, object_id, vec![]).await;
let openai = openai_client();
@@ -209,10 +226,13 @@
// resolve embeddings from OpenAI
let embeddings = indexer.get_embeddings(fragments).await.unwrap();
assert_eq!(embeddings[0].embedding.is_some(), true);
assert_eq!(embeddings.fragments[0].embedding.is_some(), true);
// store embeddings in DB
indexer.store_embeddings(embeddings).await.unwrap();
indexer
.store_embeddings(&workspace_id, embeddings)
.await
.unwrap();
// search for embedding
let mut tx = indexer.db.begin().await.unwrap();

@@ -30,9 +30,16 @@ pub async fn search_document(
.await
.map_err(|e| AppResponseError::new(ErrorCode::Internal, e.to_string()))?;
if let Some(usage) = embeddings.usage {
tracing::info!("OpenAI API usage: {}", usage.total_tokens)
}
let tokens_used = if let Some(usage) = embeddings.usage {
tracing::info!(
"workspace {} OpenAI API search tokens used: {}",
workspace_id,
usage.total_tokens
);
usage.total_tokens
} else {
0
};
let embedding = embeddings
.data
@@ -61,8 +68,10 @@ pub async fn search_document(
preview: request.preview_size.unwrap_or(180) as i32,
embedding,
},
tokens_used,
)
.await?;
tx.commit().await?;
Ok(
results
.into_iter()