feat: chat settings (#1044)

* feat: chat settings

* chore: fix sqlx
This commit is contained in:
Nathan.fooo 2024-12-05 23:30:11 +08:00 committed by GitHub
parent 445d3af5fa
commit afcd1130c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 358 additions and 64 deletions

View file

@ -0,0 +1,34 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT name, meta_data, rag_ids\n FROM af_chat\n WHERE chat_id = $1 AND deleted_at IS NULL\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "name",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "meta_data",
"type_info": "Jsonb"
},
{
"ordinal": 2,
"name": "rag_ids",
"type_info": "Jsonb"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false,
false,
false
]
},
"hash": "66218110851919b05b95b008a17547547d23f6baeeff8a5521b2b246126adc34"
}

View file

@ -25,7 +25,7 @@ pub fn load_env() {
});
}
pub fn local_ai_test_enabled() -> bool {
pub fn ai_test_enabled() -> bool {
// In appflowy GitHub CI, we enable 'ai-test-enabled' feature by default, so even if the env var is not set,
// we still enable the local ai test.
if cfg!(feature = "ai-test-enabled") {

View file

@ -13,6 +13,7 @@ use shared_entity::dto::ai_dto::{
CalculateSimilarityParams, RepeatedRelatedQuestion, SimilarityResponse, STREAM_ANSWER_KEY,
STREAM_METADATA_KEY,
};
use shared_entity::dto::chat_dto::{ChatSettings, UpdateChatParams};
use shared_entity::response::{AppResponse, AppResponseError};
use std::pin::Pin;
use std::task::{Context, Poll};
@ -37,6 +38,45 @@ impl Client {
AppResponse::<()>::from_response(resp).await?.into_error()
}
pub async fn update_chat_settings(
&self,
workspace_id: &str,
chat_id: &str,
params: UpdateChatParams,
) -> Result<(), AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/settings",
self.base_url
);
let resp = self
.http_client_with_auth(Method::POST, &url)
.await?
.json(&params)
.send()
.await?;
log_request_id(&resp);
AppResponse::<()>::from_response(resp).await?.into_error()
}
pub async fn get_chat_settings(
&self,
workspace_id: &str,
chat_id: &str,
) -> Result<ChatSettings, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/settings",
self.base_url
);
let resp = self
.http_client_with_auth(Method::GET, &url)
.await?
.send()
.await?;
log_request_id(&resp);
AppResponse::<ChatSettings>::from_response(resp)
.await?
.into_data()
}
/// Delete a chat for given chat_id
pub async fn delete_chat(
&self,
@ -103,10 +143,10 @@ impl Client {
&self,
workspace_id: &str,
chat_id: &str,
question_message_id: i64,
question_id: i64,
) -> Result<QuestionStream, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/{question_message_id}/v2/answer/stream",
"{}/api/chat/{workspace_id}/{chat_id}/{question_id}/v2/answer/stream",
self.base_url
);
let resp = self
@ -193,7 +233,10 @@ impl Client {
offset: MessageCursor,
limit: u64,
) -> Result<RepeatedChatMessage, AppResponseError> {
let mut url = format!("{}/api/chat/{workspace_id}/{chat_id}", self.base_url);
let mut url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/message",
self.base_url
);
let mut query_params = vec![("limit", limit.to_string())];
match offset {
MessageCursor::Offset(offset_value) => {

View file

@ -1,12 +1,11 @@
use crate::pg_row::AFChatRow;
use crate::workspace::is_workspace_exist;
use anyhow::anyhow;
use app_error::AppError;
use chrono::{DateTime, Utc};
use shared_entity::dto::chat_dto::{
ChatAuthor, ChatMessage, ChatMessageMetadata, CreateChatParams, GetChatMessageParams,
MessageCursor, RepeatedChatMessage, UpdateChatMessageContentParams, UpdateChatMessageMetaParams,
UpdateChatParams,
ChatAuthor, ChatMessage, ChatMessageMetadata, ChatSettings, CreateChatParams,
GetChatMessageParams, MessageCursor, RepeatedChatMessage, UpdateChatMessageContentParams,
UpdateChatMessageMetaParams, UpdateChatParams,
};
use serde_json::json;
@ -19,19 +18,13 @@ use tracing::warn;
use uuid::Uuid;
pub async fn insert_chat(
txn: &mut Transaction<'_, Postgres>,
pub async fn insert_chat<'a, E: Executor<'a, Database = Postgres>>(
executor: E,
workspace_id: &str,
params: CreateChatParams,
) -> Result<(), AppError> {
let chat_id = Uuid::from_str(&params.chat_id)?;
let workspace_id = Uuid::from_str(workspace_id)?;
if !is_workspace_exist(txn.deref_mut(), &workspace_id).await? {
return Err(AppError::RecordNotFound(format!(
"workspace with given id:{} is not found",
workspace_id
)));
}
let rag_ids = json!(params.rag_ids);
sqlx::query!(
r#"
@ -43,25 +36,40 @@ pub async fn insert_chat(
workspace_id,
rag_ids,
)
.execute(txn.deref_mut())
.execute(executor)
.await
.map_err(|err| AppError::Internal(anyhow!("Failed to insert chat: {}", err)))?;
Ok(())
}
/// Updates specific fields of a chat record in the database using transactional queries.
///
/// This function dynamically builds an SQL `UPDATE` query based on the provided parameters to
/// update fields of a specific chat record identified by `chat_id`. It uses a transaction to ensure
/// that the update operation is atomic.
///
pub async fn update_chat(
txn: &mut Transaction<'_, Postgres>,
pub async fn select_chat_settings<'a, E: Executor<'a, Database = Postgres>>(
executor: E,
chat_id: &Uuid,
) -> Result<ChatSettings, AppError> {
let row = sqlx::query!(
r#"
SELECT name, meta_data, rag_ids
FROM af_chat
WHERE chat_id = $1 AND deleted_at IS NULL
"#,
&chat_id,
)
.fetch_one(executor)
.await?;
let rag_ids = serde_json::from_value::<Vec<String>>(row.rag_ids).unwrap_or_default();
Ok(ChatSettings {
name: row.name,
rag_ids,
metadata: row.meta_data,
})
}
pub async fn update_chat_settings<'a, E: Executor<'a, Database = Postgres>>(
executor: E,
chat_id: &Uuid,
params: UpdateChatParams,
) -> Result<(), AppError> {
let mut query_parts = vec!["UPDATE af_chat SET".to_string()];
let mut query_parts = vec![];
let mut args = PgArguments::default();
let mut current_param_pos = 1; // Start counting SQL parameters from 1
@ -77,7 +85,7 @@ pub async fn update_chat(
}
if let Some(ref metadata) = params.metadata {
query_parts.push(format!("metadata = metadata || ${}", current_param_pos));
query_parts.push(format!("meta_data = meta_data || ${}", current_param_pos));
args
.add(json!(metadata))
.map_err(|err| AppError::SqlxArgEncodingError {
@ -87,7 +95,27 @@ pub async fn update_chat(
current_param_pos += 1;
}
query_parts.push(format!("WHERE chat_id = ${}", current_param_pos));
if let Some(rag_ids) = params.rag_ids {
query_parts.push(format!("rag_ids = ${}", current_param_pos));
args
.add(json!(rag_ids))
.map_err(|err| AppError::SqlxArgEncodingError {
desc: format!("unable to encode rag ids for chat id {}", chat_id),
err,
})?;
current_param_pos += 1;
}
if query_parts.is_empty() {
// If no fields to update, skip execution
return Ok(());
}
let query = format!(
"UPDATE af_chat SET {} WHERE chat_id = ${}",
query_parts.join(", "),
current_param_pos
);
args
.add(chat_id)
.map_err(|err| AppError::SqlxArgEncodingError {
@ -95,9 +123,11 @@ pub async fn update_chat(
err,
})?;
let query = query_parts.join(", ") + ";";
let query = sqlx::query_with(&query, args);
query.execute(txn.deref_mut()).await?;
sqlx::query_with(&query, args)
.execute(executor)
.await
.map_err(|err| AppError::Internal(anyhow!("Failed to update chat settings: {}", err)))?;
Ok(())
}

View file

@ -18,13 +18,13 @@ pub struct CreateChatParams {
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
pub struct UpdateChatParams {
#[validate(custom = "validate_not_empty_str")]
pub chat_id: String,
#[validate(custom = "validate_not_empty_str")]
pub name: Option<String>,
/// Key-value pairs of metadata to be updated.
pub metadata: Option<serde_json::Value>,
pub rag_ids: Option<Vec<String>>,
}
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
@ -308,6 +308,14 @@ pub struct RepeatedChatMessage {
pub total: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatSettings {
// Currently we have not used the `name` field in the ChatSettings
pub name: String,
pub rag_ids: Vec<String>,
pub metadata: serde_json::Value,
}
#[derive(Debug, Default, Clone, Serialize_repr, Deserialize_repr)]
#[repr(u8)]
pub enum ChatAuthorType {

View file

@ -1355,7 +1355,7 @@ async fn get_encode_collab_from_bytes(
.map_err(|err| ImportError::Internal(err.into()))?,
)
},
Err(err) => return Err(err.into()),
Err(err) => Err(err.into()),
}
}

View file

@ -6,18 +6,20 @@ use crate::state::AppState;
use actix_web::web::{Data, Json};
use actix_web::{web, HttpRequest, HttpResponse, Scope};
use crate::api::util::ai_model_from_header;
use app_error::AppError;
use appflowy_ai_client::dto::{CreateChatContext, RepeatedRelatedQuestion};
use authentication::jwt::UserUuid;
use bytes::Bytes;
use database::chat;
use futures::Stream;
use futures_util::stream;
use futures_util::{FutureExt, TryStreamExt};
use pin_project::pin_project;
use shared_entity::dto::chat_dto::{
ChatAuthor, ChatMessage, CreateAnswerMessageParams, CreateChatMessageParams,
ChatAuthor, ChatMessage, ChatSettings, CreateAnswerMessageParams, CreateChatMessageParams,
CreateChatMessageParamsV2, CreateChatParams, GetChatMessageParams, MessageCursor,
RepeatedChatMessage, UpdateChatMessageContentParams,
RepeatedChatMessage, UpdateChatMessageContentParams, UpdateChatParams,
};
use shared_entity::response::{AppResponse, JsonAppResponse};
use std::collections::HashMap;
@ -25,17 +27,12 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::oneshot;
use tokio::task;
use database::chat;
use crate::api::util::ai_model_from_header;
use database::chat::chat_ops::insert_answer_message;
use tracing::{error, instrument, trace};
use uuid::Uuid;
use validator::Validate;
pub fn chat_scope() -> Scope {
web::scope("/api/chat/{workspace_id}")
// Chat management
// Chat CRUD
.service(
web::resource("")
.route(web::post().to(create_chat_handler))
@ -43,13 +40,22 @@ pub fn chat_scope() -> Scope {
.service(
web::resource("/{chat_id}")
.route(web::delete().to(delete_chat_handler))
// Deprecated, use /message instead
.route(web::get().to(get_chat_message_handler))
)
// Settings
.service(
web::resource("/{chat_id}/settings")
.route(web::get().to(get_chat_settings_handler))
.route(web::post().to(update_chat_settings_handler))
)
// Message management
.service(
web::resource("/{chat_id}/message")
.route(web::put().to(update_question_handler))
.route(web::get().to(get_chat_message_handler))
)
.service(
web::resource("/{chat_id}/message/question")
@ -202,7 +208,7 @@ async fn save_answer_handler(
payload.validate().map_err(AppError::from)?;
let (_workspace_id, chat_id) = path.into_inner();
let message = insert_answer_message(
let message = database::chat::chat_ops::insert_answer_message(
&state.pg_pool,
ChatAuthor::ai(),
&chat_id,
@ -343,6 +349,28 @@ async fn get_chat_message_handler(
Ok(AppResponse::Ok().with_data(messages).into())
}
#[instrument(level = "debug", skip_all, err)]
async fn get_chat_settings_handler(
path: web::Path<(String, String)>,
state: Data<AppState>,
) -> actix_web::Result<JsonAppResponse<ChatSettings>> {
let (_, chat_id) = path.into_inner();
let chat_id_uuid = Uuid::parse_str(&chat_id).map_err(AppError::from)?;
let settings = chat::chat_ops::select_chat_settings(&state.pg_pool, &chat_id_uuid).await?;
Ok(AppResponse::Ok().with_data(settings).into())
}
async fn update_chat_settings_handler(
path: web::Path<(String, String)>,
state: Data<AppState>,
payload: Json<UpdateChatParams>,
) -> actix_web::Result<JsonAppResponse<()>> {
let (_workspace_id, chat_id) = path.into_inner();
let chat_id_uuid = Uuid::parse_str(&chat_id).map_err(AppError::from)?;
chat::chat_ops::update_chat_settings(&state.pg_pool, &chat_id_uuid, payload.into_inner()).await?;
Ok(AppResponse::Ok().into())
}
#[pin_project]
pub struct FinalAnswerStream<S, F> {
#[pin]

View file

@ -30,9 +30,7 @@ pub(crate) async fn create_chat(
params.validate()?;
trace!("[Chat] create chat {:?}", params);
let mut txn = pg_pool.begin().await?;
insert_chat(&mut txn, workspace_id, params).await?;
txn.commit().await?;
insert_chat(pg_pool, workspace_id, params).await?;
Ok(())
}

View file

@ -2,16 +2,103 @@ use crate::ai_test::util::read_text_from_asset;
use assert_json_diff::{assert_json_eq, assert_json_include};
use client_api::entity::{QuestionStream, QuestionStreamValue};
use client_api_test::{local_ai_test_enabled, TestClient};
use client_api_test::{ai_test_enabled, TestClient};
use futures_util::StreamExt;
use serde_json::json;
use shared_entity::dto::chat_dto::{
ChatMessageMetadata, ChatRAGData, CreateChatMessageParams, CreateChatParams, MessageCursor,
UpdateChatParams,
};
#[tokio::test]
async fn update_chat_settings_test() {
if !ai_test_enabled() {
return;
}
let test_client = TestClient::new_user_without_ws_conn().await;
let workspace_id = test_client.workspace_id().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let params = CreateChatParams {
chat_id: chat_id.clone(),
name: "my first chat".to_string(),
rag_ids: vec![],
};
test_client
.api_client
.create_chat(&workspace_id, params)
.await
.unwrap();
// Update name and rag_ids
test_client
.api_client
.update_chat_settings(
&workspace_id,
&chat_id,
UpdateChatParams {
name: Some("my second chat".to_string()),
metadata: None,
rag_ids: Some(vec!["rag1".to_string(), "rag2".to_string()]),
},
)
.await
.unwrap();
// Get chat settings and check if the name and rag_ids are updated
let settings = test_client
.api_client
.get_chat_settings(&workspace_id, &chat_id)
.await
.unwrap();
assert_eq!(settings.name, "my second chat");
assert_eq!(
settings.rag_ids,
vec!["rag1".to_string(), "rag2".to_string()]
);
// Update chat metadata
test_client
.api_client
.update_chat_settings(
&workspace_id,
&chat_id,
UpdateChatParams {
name: None,
metadata: Some(json!({"1": "A"})),
rag_ids: None,
},
)
.await
.unwrap();
test_client
.api_client
.update_chat_settings(
&workspace_id,
&chat_id,
UpdateChatParams {
name: None,
metadata: Some(json!({"2": "B"})),
rag_ids: None,
},
)
.await
.unwrap();
// check if the metadata is updated
let settings = test_client
.api_client
.get_chat_settings(&workspace_id, &chat_id)
.await
.unwrap();
assert_eq!(settings.metadata, json!({"1": "A", "2": "B"}));
}
#[tokio::test]
async fn create_chat_and_create_messages_test() {
if !local_ai_test_enabled() {
if !ai_test_enabled() {
return;
}
@ -100,7 +187,7 @@ async fn create_chat_and_create_messages_test() {
#[tokio::test]
async fn chat_qa_test() {
if !local_ai_test_enabled() {
if !ai_test_enabled() {
return;
}
let test_client = TestClient::new_user_without_ws_conn().await;
@ -175,7 +262,7 @@ async fn chat_qa_test() {
#[tokio::test]
async fn generate_chat_message_answer_test() {
if !local_ai_test_enabled() {
if !ai_test_enabled() {
return;
}
let test_client = TestClient::new_user_without_ws_conn().await;
@ -209,7 +296,7 @@ async fn generate_chat_message_answer_test() {
#[tokio::test]
async fn create_chat_context_test() {
if !local_ai_test_enabled() {
if !ai_test_enabled() {
return;
}
let test_client = TestClient::new_user_without_ws_conn().await;

View file

@ -1,10 +1,10 @@
use appflowy_ai_client::dto::{AIModel, CompletionType};
use client_api_test::{local_ai_test_enabled, TestClient};
use client_api_test::{ai_test_enabled, TestClient};
use shared_entity::dto::ai_dto::CompleteTextParams;
#[tokio::test]
async fn improve_writing_test() {
if !local_ai_test_enabled() {
if !ai_test_enabled() {
return;
}
let test_client = TestClient::new_user().await;

View file

@ -1,10 +1,10 @@
use client_api_test::{local_ai_test_enabled, TestClient};
use client_api_test::{ai_test_enabled, TestClient};
use serde_json::json;
use shared_entity::dto::ai_dto::{SummarizeRowData, SummarizeRowParams};
#[tokio::test]
async fn summarize_row_test() {
if !local_ai_test_enabled() {
if !ai_test_enabled() {
return;
}
let test_client = TestClient::new_user().await;

View file

@ -1,12 +1,14 @@
use crate::sql_test::util::{setup_db, test_create_user};
use database::chat::chat_ops::{
delete_chat, get_all_chat_messages, insert_chat, insert_question_message, select_chat,
select_chat_messages,
select_chat_messages, select_chat_settings, update_chat_settings,
};
use serde_json::json;
use shared_entity::dto::chat_dto::{
ChatAuthor, ChatAuthorType, CreateChatParams, GetChatMessageParams,
};
use shared_entity::dto::chat_dto::UpdateChatParams;
use sqlx::PgPool;
#[sqlx::test(migrations = false)]
@ -23,9 +25,8 @@ async fn chat_crud_test(pool: PgPool) {
let chat_id = uuid::Uuid::new_v4().to_string();
// create chat
{
let mut txn = pool.begin().await.unwrap();
insert_chat(
&mut txn,
&pool,
&user.workspace_id,
CreateChatParams {
chat_id: chat_id.clone(),
@ -35,7 +36,6 @@ async fn chat_crud_test(pool: PgPool) {
)
.await
.unwrap();
txn.commit().await.unwrap();
}
// get chat
@ -76,9 +76,8 @@ async fn chat_message_crud_test(pool: PgPool) {
let chat_id = uuid::Uuid::new_v4().to_string();
// create chat
{
let mut txn = pool.begin().await.unwrap();
insert_chat(
&mut txn,
&pool,
&user.workspace_id,
CreateChatParams {
chat_id: chat_id.clone(),
@ -88,7 +87,6 @@ async fn chat_message_crud_test(pool: PgPool) {
)
.await
.unwrap();
txn.commit().await.unwrap();
}
// create chat messages
@ -183,3 +181,71 @@ async fn chat_message_crud_test(pool: PgPool) {
assert_eq!(messages.len(), 5);
}
}
#[sqlx::test(migrations = false)]
async fn chat_setting_test(pool: PgPool) {
setup_db(&pool).await.unwrap();
let user_uuid = uuid::Uuid::new_v4();
let name = user_uuid.to_string();
let email = format!("{}@appflowy.io", name);
let user = test_create_user(&pool, user_uuid, &email, &name)
.await
.unwrap();
let workspace_id = user.workspace_id;
let chat_id = uuid::Uuid::new_v4();
// Insert initial chat data with rag_ids
let insert_params = CreateChatParams {
chat_id: chat_id.to_string(),
name: "Initial Chat".to_string(),
rag_ids: vec!["rag1".to_string(), "rag2".to_string()],
};
insert_chat(&pool, &workspace_id, insert_params)
.await
.expect("Failed to insert chat");
// Validate inserted rag_ids
let settings = select_chat_settings(&pool, &chat_id)
.await
.expect("Failed to get chat settings");
assert_eq!(settings.rag_ids, vec!["rag1", "rag2"]);
// Update metadata
let update_params = UpdateChatParams {
name: None,
metadata: Some(json!({"key": "value"})),
rag_ids: None,
};
update_chat_settings(&pool, &chat_id, update_params)
.await
.expect("Failed to update chat settings");
// Validate metadata update
let settings = select_chat_settings(&pool, &chat_id)
.await
.expect("Failed to get chat settings");
assert_eq!(settings.metadata, json!({"key": "value"}));
// Update rag_ids and metadata together
let update_params = UpdateChatParams {
name: None,
metadata: Some(json!({"new_key": "new_value"})),
rag_ids: Some(vec!["rag3".to_string(), "rag4".to_string()]),
};
update_chat_settings(&pool, &chat_id, update_params)
.await
.expect("Failed to update chat settings");
// Validate both rag_ids and metadata
let settings = select_chat_settings(&pool, &chat_id)
.await
.expect("Failed to get chat settings");
assert_eq!(
settings.metadata,
json!({"key": "value", "new_key": "new_value"})
);
assert_eq!(settings.rag_ids, vec!["rag3", "rag4"]);
}