mirror of
https://github.com/AppFlowy-IO/AppFlowy-Cloud.git
synced 2025-04-19 03:24:42 -04:00
parent
445d3af5fa
commit
afcd1130c3
12 changed files with 358 additions and 64 deletions
34
.sqlx/query-66218110851919b05b95b008a17547547d23f6baeeff8a5521b2b246126adc34.json
generated
Normal file
34
.sqlx/query-66218110851919b05b95b008a17547547d23f6baeeff8a5521b2b246126adc34.json
generated
Normal file
|
@ -0,0 +1,34 @@
|
|||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT name, meta_data, rag_ids\n FROM af_chat\n WHERE chat_id = $1 AND deleted_at IS NULL\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "name",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "meta_data",
|
||||
"type_info": "Jsonb"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "rag_ids",
|
||||
"type_info": "Jsonb"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "66218110851919b05b95b008a17547547d23f6baeeff8a5521b2b246126adc34"
|
||||
}
|
|
@ -25,7 +25,7 @@ pub fn load_env() {
|
|||
});
|
||||
}
|
||||
|
||||
pub fn local_ai_test_enabled() -> bool {
|
||||
pub fn ai_test_enabled() -> bool {
|
||||
// In appflowy GitHub CI, we enable 'ai-test-enabled' feature by default, so even if the env var is not set,
|
||||
// we still enable the local ai test.
|
||||
if cfg!(feature = "ai-test-enabled") {
|
||||
|
|
|
@ -13,6 +13,7 @@ use shared_entity::dto::ai_dto::{
|
|||
CalculateSimilarityParams, RepeatedRelatedQuestion, SimilarityResponse, STREAM_ANSWER_KEY,
|
||||
STREAM_METADATA_KEY,
|
||||
};
|
||||
use shared_entity::dto::chat_dto::{ChatSettings, UpdateChatParams};
|
||||
use shared_entity::response::{AppResponse, AppResponseError};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
@ -37,6 +38,45 @@ impl Client {
|
|||
AppResponse::<()>::from_response(resp).await?.into_error()
|
||||
}
|
||||
|
||||
pub async fn update_chat_settings(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
params: UpdateChatParams,
|
||||
) -> Result<(), AppResponseError> {
|
||||
let url = format!(
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/settings",
|
||||
self.base_url
|
||||
);
|
||||
let resp = self
|
||||
.http_client_with_auth(Method::POST, &url)
|
||||
.await?
|
||||
.json(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
log_request_id(&resp);
|
||||
AppResponse::<()>::from_response(resp).await?.into_error()
|
||||
}
|
||||
pub async fn get_chat_settings(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
) -> Result<ChatSettings, AppResponseError> {
|
||||
let url = format!(
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/settings",
|
||||
self.base_url
|
||||
);
|
||||
let resp = self
|
||||
.http_client_with_auth(Method::GET, &url)
|
||||
.await?
|
||||
.send()
|
||||
.await?;
|
||||
log_request_id(&resp);
|
||||
AppResponse::<ChatSettings>::from_response(resp)
|
||||
.await?
|
||||
.into_data()
|
||||
}
|
||||
|
||||
/// Delete a chat for given chat_id
|
||||
pub async fn delete_chat(
|
||||
&self,
|
||||
|
@ -103,10 +143,10 @@ impl Client {
|
|||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
question_message_id: i64,
|
||||
question_id: i64,
|
||||
) -> Result<QuestionStream, AppResponseError> {
|
||||
let url = format!(
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/{question_message_id}/v2/answer/stream",
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/{question_id}/v2/answer/stream",
|
||||
self.base_url
|
||||
);
|
||||
let resp = self
|
||||
|
@ -193,7 +233,10 @@ impl Client {
|
|||
offset: MessageCursor,
|
||||
limit: u64,
|
||||
) -> Result<RepeatedChatMessage, AppResponseError> {
|
||||
let mut url = format!("{}/api/chat/{workspace_id}/{chat_id}", self.base_url);
|
||||
let mut url = format!(
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/message",
|
||||
self.base_url
|
||||
);
|
||||
let mut query_params = vec![("limit", limit.to_string())];
|
||||
match offset {
|
||||
MessageCursor::Offset(offset_value) => {
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
use crate::pg_row::AFChatRow;
|
||||
use crate::workspace::is_workspace_exist;
|
||||
use anyhow::anyhow;
|
||||
use app_error::AppError;
|
||||
use chrono::{DateTime, Utc};
|
||||
use shared_entity::dto::chat_dto::{
|
||||
ChatAuthor, ChatMessage, ChatMessageMetadata, CreateChatParams, GetChatMessageParams,
|
||||
MessageCursor, RepeatedChatMessage, UpdateChatMessageContentParams, UpdateChatMessageMetaParams,
|
||||
UpdateChatParams,
|
||||
ChatAuthor, ChatMessage, ChatMessageMetadata, ChatSettings, CreateChatParams,
|
||||
GetChatMessageParams, MessageCursor, RepeatedChatMessage, UpdateChatMessageContentParams,
|
||||
UpdateChatMessageMetaParams, UpdateChatParams,
|
||||
};
|
||||
|
||||
use serde_json::json;
|
||||
|
@ -19,19 +18,13 @@ use tracing::warn;
|
|||
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn insert_chat(
|
||||
txn: &mut Transaction<'_, Postgres>,
|
||||
pub async fn insert_chat<'a, E: Executor<'a, Database = Postgres>>(
|
||||
executor: E,
|
||||
workspace_id: &str,
|
||||
params: CreateChatParams,
|
||||
) -> Result<(), AppError> {
|
||||
let chat_id = Uuid::from_str(¶ms.chat_id)?;
|
||||
let workspace_id = Uuid::from_str(workspace_id)?;
|
||||
if !is_workspace_exist(txn.deref_mut(), &workspace_id).await? {
|
||||
return Err(AppError::RecordNotFound(format!(
|
||||
"workspace with given id:{} is not found",
|
||||
workspace_id
|
||||
)));
|
||||
}
|
||||
let rag_ids = json!(params.rag_ids);
|
||||
sqlx::query!(
|
||||
r#"
|
||||
|
@ -43,25 +36,40 @@ pub async fn insert_chat(
|
|||
workspace_id,
|
||||
rag_ids,
|
||||
)
|
||||
.execute(txn.deref_mut())
|
||||
.execute(executor)
|
||||
.await
|
||||
.map_err(|err| AppError::Internal(anyhow!("Failed to insert chat: {}", err)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Updates specific fields of a chat record in the database using transactional queries.
|
||||
///
|
||||
/// This function dynamically builds an SQL `UPDATE` query based on the provided parameters to
|
||||
/// update fields of a specific chat record identified by `chat_id`. It uses a transaction to ensure
|
||||
/// that the update operation is atomic.
|
||||
///
|
||||
pub async fn update_chat(
|
||||
txn: &mut Transaction<'_, Postgres>,
|
||||
pub async fn select_chat_settings<'a, E: Executor<'a, Database = Postgres>>(
|
||||
executor: E,
|
||||
chat_id: &Uuid,
|
||||
) -> Result<ChatSettings, AppError> {
|
||||
let row = sqlx::query!(
|
||||
r#"
|
||||
SELECT name, meta_data, rag_ids
|
||||
FROM af_chat
|
||||
WHERE chat_id = $1 AND deleted_at IS NULL
|
||||
"#,
|
||||
&chat_id,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await?;
|
||||
let rag_ids = serde_json::from_value::<Vec<String>>(row.rag_ids).unwrap_or_default();
|
||||
Ok(ChatSettings {
|
||||
name: row.name,
|
||||
rag_ids,
|
||||
metadata: row.meta_data,
|
||||
})
|
||||
}
|
||||
pub async fn update_chat_settings<'a, E: Executor<'a, Database = Postgres>>(
|
||||
executor: E,
|
||||
chat_id: &Uuid,
|
||||
params: UpdateChatParams,
|
||||
) -> Result<(), AppError> {
|
||||
let mut query_parts = vec!["UPDATE af_chat SET".to_string()];
|
||||
let mut query_parts = vec![];
|
||||
let mut args = PgArguments::default();
|
||||
let mut current_param_pos = 1; // Start counting SQL parameters from 1
|
||||
|
||||
|
@ -77,7 +85,7 @@ pub async fn update_chat(
|
|||
}
|
||||
|
||||
if let Some(ref metadata) = params.metadata {
|
||||
query_parts.push(format!("metadata = metadata || ${}", current_param_pos));
|
||||
query_parts.push(format!("meta_data = meta_data || ${}", current_param_pos));
|
||||
args
|
||||
.add(json!(metadata))
|
||||
.map_err(|err| AppError::SqlxArgEncodingError {
|
||||
|
@ -87,7 +95,27 @@ pub async fn update_chat(
|
|||
current_param_pos += 1;
|
||||
}
|
||||
|
||||
query_parts.push(format!("WHERE chat_id = ${}", current_param_pos));
|
||||
if let Some(rag_ids) = params.rag_ids {
|
||||
query_parts.push(format!("rag_ids = ${}", current_param_pos));
|
||||
args
|
||||
.add(json!(rag_ids))
|
||||
.map_err(|err| AppError::SqlxArgEncodingError {
|
||||
desc: format!("unable to encode rag ids for chat id {}", chat_id),
|
||||
err,
|
||||
})?;
|
||||
current_param_pos += 1;
|
||||
}
|
||||
|
||||
if query_parts.is_empty() {
|
||||
// If no fields to update, skip execution
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let query = format!(
|
||||
"UPDATE af_chat SET {} WHERE chat_id = ${}",
|
||||
query_parts.join(", "),
|
||||
current_param_pos
|
||||
);
|
||||
args
|
||||
.add(chat_id)
|
||||
.map_err(|err| AppError::SqlxArgEncodingError {
|
||||
|
@ -95,9 +123,11 @@ pub async fn update_chat(
|
|||
err,
|
||||
})?;
|
||||
|
||||
let query = query_parts.join(", ") + ";";
|
||||
let query = sqlx::query_with(&query, args);
|
||||
query.execute(txn.deref_mut()).await?;
|
||||
sqlx::query_with(&query, args)
|
||||
.execute(executor)
|
||||
.await
|
||||
.map_err(|err| AppError::Internal(anyhow!("Failed to update chat settings: {}", err)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -18,13 +18,13 @@ pub struct CreateChatParams {
|
|||
|
||||
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
|
||||
pub struct UpdateChatParams {
|
||||
#[validate(custom = "validate_not_empty_str")]
|
||||
pub chat_id: String,
|
||||
|
||||
#[validate(custom = "validate_not_empty_str")]
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Key-value pairs of metadata to be updated.
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
|
||||
pub rag_ids: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
|
||||
|
@ -308,6 +308,14 @@ pub struct RepeatedChatMessage {
|
|||
pub total: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatSettings {
|
||||
// Currently we have not used the `name` field in the ChatSettings
|
||||
pub name: String,
|
||||
pub rag_ids: Vec<String>,
|
||||
pub metadata: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize_repr, Deserialize_repr)]
|
||||
#[repr(u8)]
|
||||
pub enum ChatAuthorType {
|
||||
|
|
|
@ -1355,7 +1355,7 @@ async fn get_encode_collab_from_bytes(
|
|||
.map_err(|err| ImportError::Internal(err.into()))?,
|
||||
)
|
||||
},
|
||||
Err(err) => return Err(err.into()),
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,18 +6,20 @@ use crate::state::AppState;
|
|||
use actix_web::web::{Data, Json};
|
||||
use actix_web::{web, HttpRequest, HttpResponse, Scope};
|
||||
|
||||
use crate::api::util::ai_model_from_header;
|
||||
use app_error::AppError;
|
||||
use appflowy_ai_client::dto::{CreateChatContext, RepeatedRelatedQuestion};
|
||||
use authentication::jwt::UserUuid;
|
||||
use bytes::Bytes;
|
||||
use database::chat;
|
||||
use futures::Stream;
|
||||
use futures_util::stream;
|
||||
use futures_util::{FutureExt, TryStreamExt};
|
||||
use pin_project::pin_project;
|
||||
use shared_entity::dto::chat_dto::{
|
||||
ChatAuthor, ChatMessage, CreateAnswerMessageParams, CreateChatMessageParams,
|
||||
ChatAuthor, ChatMessage, ChatSettings, CreateAnswerMessageParams, CreateChatMessageParams,
|
||||
CreateChatMessageParamsV2, CreateChatParams, GetChatMessageParams, MessageCursor,
|
||||
RepeatedChatMessage, UpdateChatMessageContentParams,
|
||||
RepeatedChatMessage, UpdateChatMessageContentParams, UpdateChatParams,
|
||||
};
|
||||
use shared_entity::response::{AppResponse, JsonAppResponse};
|
||||
use std::collections::HashMap;
|
||||
|
@ -25,17 +27,12 @@ use std::pin::Pin;
|
|||
use std::task::{Context, Poll};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task;
|
||||
|
||||
use database::chat;
|
||||
|
||||
use crate::api::util::ai_model_from_header;
|
||||
|
||||
use database::chat::chat_ops::insert_answer_message;
|
||||
use tracing::{error, instrument, trace};
|
||||
use uuid::Uuid;
|
||||
use validator::Validate;
|
||||
pub fn chat_scope() -> Scope {
|
||||
web::scope("/api/chat/{workspace_id}")
|
||||
// Chat management
|
||||
// Chat CRUD
|
||||
.service(
|
||||
web::resource("")
|
||||
.route(web::post().to(create_chat_handler))
|
||||
|
@ -43,13 +40,22 @@ pub fn chat_scope() -> Scope {
|
|||
.service(
|
||||
web::resource("/{chat_id}")
|
||||
.route(web::delete().to(delete_chat_handler))
|
||||
// Deprecated, use /message instead
|
||||
.route(web::get().to(get_chat_message_handler))
|
||||
)
|
||||
|
||||
// Settings
|
||||
.service(
|
||||
web::resource("/{chat_id}/settings")
|
||||
.route(web::get().to(get_chat_settings_handler))
|
||||
.route(web::post().to(update_chat_settings_handler))
|
||||
)
|
||||
|
||||
// Message management
|
||||
.service(
|
||||
web::resource("/{chat_id}/message")
|
||||
.route(web::put().to(update_question_handler))
|
||||
.route(web::get().to(get_chat_message_handler))
|
||||
)
|
||||
.service(
|
||||
web::resource("/{chat_id}/message/question")
|
||||
|
@ -202,7 +208,7 @@ async fn save_answer_handler(
|
|||
payload.validate().map_err(AppError::from)?;
|
||||
|
||||
let (_workspace_id, chat_id) = path.into_inner();
|
||||
let message = insert_answer_message(
|
||||
let message = database::chat::chat_ops::insert_answer_message(
|
||||
&state.pg_pool,
|
||||
ChatAuthor::ai(),
|
||||
&chat_id,
|
||||
|
@ -343,6 +349,28 @@ async fn get_chat_message_handler(
|
|||
Ok(AppResponse::Ok().with_data(messages).into())
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn get_chat_settings_handler(
|
||||
path: web::Path<(String, String)>,
|
||||
state: Data<AppState>,
|
||||
) -> actix_web::Result<JsonAppResponse<ChatSettings>> {
|
||||
let (_, chat_id) = path.into_inner();
|
||||
let chat_id_uuid = Uuid::parse_str(&chat_id).map_err(AppError::from)?;
|
||||
let settings = chat::chat_ops::select_chat_settings(&state.pg_pool, &chat_id_uuid).await?;
|
||||
Ok(AppResponse::Ok().with_data(settings).into())
|
||||
}
|
||||
|
||||
async fn update_chat_settings_handler(
|
||||
path: web::Path<(String, String)>,
|
||||
state: Data<AppState>,
|
||||
payload: Json<UpdateChatParams>,
|
||||
) -> actix_web::Result<JsonAppResponse<()>> {
|
||||
let (_workspace_id, chat_id) = path.into_inner();
|
||||
let chat_id_uuid = Uuid::parse_str(&chat_id).map_err(AppError::from)?;
|
||||
chat::chat_ops::update_chat_settings(&state.pg_pool, &chat_id_uuid, payload.into_inner()).await?;
|
||||
Ok(AppResponse::Ok().into())
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
pub struct FinalAnswerStream<S, F> {
|
||||
#[pin]
|
||||
|
|
|
@ -30,9 +30,7 @@ pub(crate) async fn create_chat(
|
|||
params.validate()?;
|
||||
trace!("[Chat] create chat {:?}", params);
|
||||
|
||||
let mut txn = pg_pool.begin().await?;
|
||||
insert_chat(&mut txn, workspace_id, params).await?;
|
||||
txn.commit().await?;
|
||||
insert_chat(pg_pool, workspace_id, params).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -2,16 +2,103 @@ use crate::ai_test::util::read_text_from_asset;
|
|||
|
||||
use assert_json_diff::{assert_json_eq, assert_json_include};
|
||||
use client_api::entity::{QuestionStream, QuestionStreamValue};
|
||||
use client_api_test::{local_ai_test_enabled, TestClient};
|
||||
use client_api_test::{ai_test_enabled, TestClient};
|
||||
use futures_util::StreamExt;
|
||||
use serde_json::json;
|
||||
use shared_entity::dto::chat_dto::{
|
||||
ChatMessageMetadata, ChatRAGData, CreateChatMessageParams, CreateChatParams, MessageCursor,
|
||||
UpdateChatParams,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn update_chat_settings_test() {
|
||||
if !ai_test_enabled() {
|
||||
return;
|
||||
}
|
||||
|
||||
let test_client = TestClient::new_user_without_ws_conn().await;
|
||||
let workspace_id = test_client.workspace_id().await;
|
||||
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let params = CreateChatParams {
|
||||
chat_id: chat_id.clone(),
|
||||
name: "my first chat".to_string(),
|
||||
rag_ids: vec![],
|
||||
};
|
||||
|
||||
test_client
|
||||
.api_client
|
||||
.create_chat(&workspace_id, params)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Update name and rag_ids
|
||||
test_client
|
||||
.api_client
|
||||
.update_chat_settings(
|
||||
&workspace_id,
|
||||
&chat_id,
|
||||
UpdateChatParams {
|
||||
name: Some("my second chat".to_string()),
|
||||
metadata: None,
|
||||
rag_ids: Some(vec!["rag1".to_string(), "rag2".to_string()]),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Get chat settings and check if the name and rag_ids are updated
|
||||
let settings = test_client
|
||||
.api_client
|
||||
.get_chat_settings(&workspace_id, &chat_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(settings.name, "my second chat");
|
||||
assert_eq!(
|
||||
settings.rag_ids,
|
||||
vec!["rag1".to_string(), "rag2".to_string()]
|
||||
);
|
||||
|
||||
// Update chat metadata
|
||||
test_client
|
||||
.api_client
|
||||
.update_chat_settings(
|
||||
&workspace_id,
|
||||
&chat_id,
|
||||
UpdateChatParams {
|
||||
name: None,
|
||||
metadata: Some(json!({"1": "A"})),
|
||||
rag_ids: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
test_client
|
||||
.api_client
|
||||
.update_chat_settings(
|
||||
&workspace_id,
|
||||
&chat_id,
|
||||
UpdateChatParams {
|
||||
name: None,
|
||||
metadata: Some(json!({"2": "B"})),
|
||||
rag_ids: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// check if the metadata is updated
|
||||
let settings = test_client
|
||||
.api_client
|
||||
.get_chat_settings(&workspace_id, &chat_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(settings.metadata, json!({"1": "A", "2": "B"}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_chat_and_create_messages_test() {
|
||||
if !local_ai_test_enabled() {
|
||||
if !ai_test_enabled() {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -100,7 +187,7 @@ async fn create_chat_and_create_messages_test() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn chat_qa_test() {
|
||||
if !local_ai_test_enabled() {
|
||||
if !ai_test_enabled() {
|
||||
return;
|
||||
}
|
||||
let test_client = TestClient::new_user_without_ws_conn().await;
|
||||
|
@ -175,7 +262,7 @@ async fn chat_qa_test() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn generate_chat_message_answer_test() {
|
||||
if !local_ai_test_enabled() {
|
||||
if !ai_test_enabled() {
|
||||
return;
|
||||
}
|
||||
let test_client = TestClient::new_user_without_ws_conn().await;
|
||||
|
@ -209,7 +296,7 @@ async fn generate_chat_message_answer_test() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn create_chat_context_test() {
|
||||
if !local_ai_test_enabled() {
|
||||
if !ai_test_enabled() {
|
||||
return;
|
||||
}
|
||||
let test_client = TestClient::new_user_without_ws_conn().await;
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use appflowy_ai_client::dto::{AIModel, CompletionType};
|
||||
use client_api_test::{local_ai_test_enabled, TestClient};
|
||||
use client_api_test::{ai_test_enabled, TestClient};
|
||||
use shared_entity::dto::ai_dto::CompleteTextParams;
|
||||
|
||||
#[tokio::test]
|
||||
async fn improve_writing_test() {
|
||||
if !local_ai_test_enabled() {
|
||||
if !ai_test_enabled() {
|
||||
return;
|
||||
}
|
||||
let test_client = TestClient::new_user().await;
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use client_api_test::{local_ai_test_enabled, TestClient};
|
||||
use client_api_test::{ai_test_enabled, TestClient};
|
||||
use serde_json::json;
|
||||
use shared_entity::dto::ai_dto::{SummarizeRowData, SummarizeRowParams};
|
||||
|
||||
#[tokio::test]
|
||||
async fn summarize_row_test() {
|
||||
if !local_ai_test_enabled() {
|
||||
if !ai_test_enabled() {
|
||||
return;
|
||||
}
|
||||
let test_client = TestClient::new_user().await;
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
use crate::sql_test::util::{setup_db, test_create_user};
|
||||
use database::chat::chat_ops::{
|
||||
delete_chat, get_all_chat_messages, insert_chat, insert_question_message, select_chat,
|
||||
select_chat_messages,
|
||||
select_chat_messages, select_chat_settings, update_chat_settings,
|
||||
};
|
||||
use serde_json::json;
|
||||
use shared_entity::dto::chat_dto::{
|
||||
ChatAuthor, ChatAuthorType, CreateChatParams, GetChatMessageParams,
|
||||
};
|
||||
|
||||
use shared_entity::dto::chat_dto::UpdateChatParams;
|
||||
use sqlx::PgPool;
|
||||
|
||||
#[sqlx::test(migrations = false)]
|
||||
|
@ -23,9 +25,8 @@ async fn chat_crud_test(pool: PgPool) {
|
|||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
// create chat
|
||||
{
|
||||
let mut txn = pool.begin().await.unwrap();
|
||||
insert_chat(
|
||||
&mut txn,
|
||||
&pool,
|
||||
&user.workspace_id,
|
||||
CreateChatParams {
|
||||
chat_id: chat_id.clone(),
|
||||
|
@ -35,7 +36,6 @@ async fn chat_crud_test(pool: PgPool) {
|
|||
)
|
||||
.await
|
||||
.unwrap();
|
||||
txn.commit().await.unwrap();
|
||||
}
|
||||
|
||||
// get chat
|
||||
|
@ -76,9 +76,8 @@ async fn chat_message_crud_test(pool: PgPool) {
|
|||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
// create chat
|
||||
{
|
||||
let mut txn = pool.begin().await.unwrap();
|
||||
insert_chat(
|
||||
&mut txn,
|
||||
&pool,
|
||||
&user.workspace_id,
|
||||
CreateChatParams {
|
||||
chat_id: chat_id.clone(),
|
||||
|
@ -88,7 +87,6 @@ async fn chat_message_crud_test(pool: PgPool) {
|
|||
)
|
||||
.await
|
||||
.unwrap();
|
||||
txn.commit().await.unwrap();
|
||||
}
|
||||
|
||||
// create chat messages
|
||||
|
@ -183,3 +181,71 @@ async fn chat_message_crud_test(pool: PgPool) {
|
|||
assert_eq!(messages.len(), 5);
|
||||
}
|
||||
}
|
||||
|
||||
#[sqlx::test(migrations = false)]
|
||||
async fn chat_setting_test(pool: PgPool) {
|
||||
setup_db(&pool).await.unwrap();
|
||||
let user_uuid = uuid::Uuid::new_v4();
|
||||
let name = user_uuid.to_string();
|
||||
let email = format!("{}@appflowy.io", name);
|
||||
let user = test_create_user(&pool, user_uuid, &email, &name)
|
||||
.await
|
||||
.unwrap();
|
||||
let workspace_id = user.workspace_id;
|
||||
let chat_id = uuid::Uuid::new_v4();
|
||||
|
||||
// Insert initial chat data with rag_ids
|
||||
let insert_params = CreateChatParams {
|
||||
chat_id: chat_id.to_string(),
|
||||
name: "Initial Chat".to_string(),
|
||||
rag_ids: vec!["rag1".to_string(), "rag2".to_string()],
|
||||
};
|
||||
|
||||
insert_chat(&pool, &workspace_id, insert_params)
|
||||
.await
|
||||
.expect("Failed to insert chat");
|
||||
|
||||
// Validate inserted rag_ids
|
||||
let settings = select_chat_settings(&pool, &chat_id)
|
||||
.await
|
||||
.expect("Failed to get chat settings");
|
||||
assert_eq!(settings.rag_ids, vec!["rag1", "rag2"]);
|
||||
|
||||
// Update metadata
|
||||
let update_params = UpdateChatParams {
|
||||
name: None,
|
||||
metadata: Some(json!({"key": "value"})),
|
||||
rag_ids: None,
|
||||
};
|
||||
|
||||
update_chat_settings(&pool, &chat_id, update_params)
|
||||
.await
|
||||
.expect("Failed to update chat settings");
|
||||
|
||||
// Validate metadata update
|
||||
let settings = select_chat_settings(&pool, &chat_id)
|
||||
.await
|
||||
.expect("Failed to get chat settings");
|
||||
assert_eq!(settings.metadata, json!({"key": "value"}));
|
||||
|
||||
// Update rag_ids and metadata together
|
||||
let update_params = UpdateChatParams {
|
||||
name: None,
|
||||
metadata: Some(json!({"new_key": "new_value"})),
|
||||
rag_ids: Some(vec!["rag3".to_string(), "rag4".to_string()]),
|
||||
};
|
||||
|
||||
update_chat_settings(&pool, &chat_id, update_params)
|
||||
.await
|
||||
.expect("Failed to update chat settings");
|
||||
|
||||
// Validate both rag_ids and metadata
|
||||
let settings = select_chat_settings(&pool, &chat_id)
|
||||
.await
|
||||
.expect("Failed to get chat settings");
|
||||
assert_eq!(
|
||||
settings.metadata,
|
||||
json!({"key": "value", "new_key": "new_value"})
|
||||
);
|
||||
assert_eq!(settings.rag_ids, vec!["rag3", "rag4"]);
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue