mirror of
https://github.com/AppFlowy-IO/AppFlowy-Cloud.git
synced 2025-04-19 03:24:42 -04:00
* chore: chat response with format * chore: update prompt * chore: update test * chore: update test * chore: fix stress test * chore: fix test * chore: test * chore: test * chore: fix stress test * chore: fix test
456 lines
12 KiB
Rust
456 lines
12 KiB
Rust
use crate::ai_test::util::read_text_from_asset;
|
|
|
|
use appflowy_ai_client::dto::{
|
|
ChatQuestionQuery, OutputContent, OutputContentMetadata, OutputLayout, ResponseFormat,
|
|
};
|
|
use assert_json_diff::assert_json_include;
|
|
use client_api::entity::{QuestionStream, QuestionStreamValue};
|
|
use client_api_test::{ai_test_enabled, TestClient};
|
|
use futures_util::StreamExt;
|
|
use serde_json::json;
|
|
use shared_entity::dto::chat_dto::{
|
|
ChatMessageMetadata, ChatRAGData, CreateAnswerMessageParams, CreateChatMessageParams,
|
|
CreateChatParams, MessageCursor, UpdateChatParams,
|
|
};
|
|
|
|
#[tokio::test]
|
|
async fn update_chat_settings_test() {
|
|
if !ai_test_enabled() {
|
|
return;
|
|
}
|
|
|
|
let test_client = TestClient::new_user_without_ws_conn().await;
|
|
let workspace_id = test_client.workspace_id().await;
|
|
|
|
let chat_id = uuid::Uuid::new_v4().to_string();
|
|
let params = CreateChatParams {
|
|
chat_id: chat_id.clone(),
|
|
name: "my first chat".to_string(),
|
|
rag_ids: vec![],
|
|
};
|
|
|
|
test_client
|
|
.api_client
|
|
.create_chat(&workspace_id, params)
|
|
.await
|
|
.unwrap();
|
|
|
|
// Update name and rag_ids
|
|
test_client
|
|
.api_client
|
|
.update_chat_settings(
|
|
&workspace_id,
|
|
&chat_id,
|
|
UpdateChatParams {
|
|
name: Some("my second chat".to_string()),
|
|
metadata: None,
|
|
rag_ids: Some(vec!["rag1".to_string(), "rag2".to_string()]),
|
|
},
|
|
)
|
|
.await
|
|
.unwrap();
|
|
|
|
// Get chat settings and check if the name and rag_ids are updated
|
|
let settings = test_client
|
|
.api_client
|
|
.get_chat_settings(&workspace_id, &chat_id)
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(settings.name, "my second chat");
|
|
assert_eq!(
|
|
settings.rag_ids,
|
|
vec!["rag1".to_string(), "rag2".to_string()]
|
|
);
|
|
|
|
// Update chat metadata
|
|
test_client
|
|
.api_client
|
|
.update_chat_settings(
|
|
&workspace_id,
|
|
&chat_id,
|
|
UpdateChatParams {
|
|
name: None,
|
|
metadata: Some(json!({"1": "A"})),
|
|
rag_ids: None,
|
|
},
|
|
)
|
|
.await
|
|
.unwrap();
|
|
test_client
|
|
.api_client
|
|
.update_chat_settings(
|
|
&workspace_id,
|
|
&chat_id,
|
|
UpdateChatParams {
|
|
name: None,
|
|
metadata: Some(json!({"2": "B"})),
|
|
rag_ids: None,
|
|
},
|
|
)
|
|
.await
|
|
.unwrap();
|
|
|
|
// check if the metadata is updated
|
|
let settings = test_client
|
|
.api_client
|
|
.get_chat_settings(&workspace_id, &chat_id)
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(settings.metadata, json!({"1": "A", "2": "B"}));
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn create_chat_and_create_messages_test() {
|
|
if !ai_test_enabled() {
|
|
return;
|
|
}
|
|
|
|
let test_client = TestClient::new_user_without_ws_conn().await;
|
|
let workspace_id = test_client.workspace_id().await;
|
|
|
|
let chat_id = uuid::Uuid::new_v4().to_string();
|
|
let params = CreateChatParams {
|
|
chat_id: chat_id.clone(),
|
|
name: "my first chat".to_string(),
|
|
rag_ids: vec![],
|
|
};
|
|
|
|
test_client
|
|
.api_client
|
|
.create_chat(&workspace_id, params)
|
|
.await
|
|
.unwrap();
|
|
|
|
let mut messages = vec![];
|
|
for i in 0..10 {
|
|
let params = CreateChatMessageParams::new_system(format!("hello world {}", i));
|
|
let question = test_client
|
|
.api_client
|
|
.create_question(&workspace_id, &chat_id, params)
|
|
.await
|
|
.unwrap();
|
|
messages.push(question);
|
|
}
|
|
// DESC is the default order
|
|
messages.reverse();
|
|
|
|
// get messages before third message. it should return first two messages even though we asked
|
|
// for 10 messages
|
|
assert_eq!(messages[7].content, "hello world 2");
|
|
let message_before_third = test_client
|
|
.api_client
|
|
.get_chat_messages(
|
|
&workspace_id,
|
|
&chat_id,
|
|
MessageCursor::BeforeMessageId(messages[7].message_id),
|
|
10,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
assert!(!message_before_third.has_more);
|
|
assert_eq!(message_before_third.messages.len(), 2);
|
|
assert_eq!(message_before_third.messages[0].content, "hello world 1");
|
|
assert_eq!(message_before_third.messages[1].content, "hello world 0");
|
|
|
|
// get message after third message
|
|
assert_eq!(messages[2].content, "hello world 7");
|
|
let message_after_third = test_client
|
|
.api_client
|
|
.get_chat_messages(
|
|
&workspace_id,
|
|
&chat_id,
|
|
MessageCursor::AfterMessageId(messages[2].message_id),
|
|
2,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
assert!(!message_after_third.has_more);
|
|
assert_eq!(message_after_third.messages.len(), 2);
|
|
assert_eq!(message_after_third.messages[0].content, "hello world 9");
|
|
assert_eq!(message_after_third.messages[1].content, "hello world 8");
|
|
|
|
let next_back = test_client
|
|
.api_client
|
|
.get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 3)
|
|
.await
|
|
.unwrap();
|
|
assert!(next_back.has_more);
|
|
assert_eq!(next_back.messages.len(), 3);
|
|
assert_eq!(next_back.messages[0].content, "hello world 9");
|
|
assert_eq!(next_back.messages[1].content, "hello world 8");
|
|
|
|
let next_back = test_client
|
|
.api_client
|
|
.get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 100)
|
|
.await
|
|
.unwrap();
|
|
assert!(!next_back.has_more);
|
|
assert_eq!(next_back.messages.len(), 10);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn chat_qa_test() {
|
|
if !ai_test_enabled() {
|
|
return;
|
|
}
|
|
let test_client = TestClient::new_user_without_ws_conn().await;
|
|
let workspace_id = test_client.workspace_id().await;
|
|
let chat_id = uuid::Uuid::new_v4().to_string();
|
|
let params = CreateChatParams {
|
|
chat_id: chat_id.clone(),
|
|
name: "new chat".to_string(),
|
|
rag_ids: vec![],
|
|
};
|
|
|
|
test_client
|
|
.api_client
|
|
.create_chat(&workspace_id, params)
|
|
.await
|
|
.unwrap();
|
|
|
|
let content = read_text_from_asset("my_profile.txt");
|
|
let metadata = ChatMessageMetadata {
|
|
data: ChatRAGData::new_text(content),
|
|
id: "123".to_string(),
|
|
name: "test context".to_string(),
|
|
source: "user added".to_string(),
|
|
extra: Some(json!({"created_at": 123})),
|
|
};
|
|
|
|
let params = CreateChatMessageParams::new_user("Where lucas live?").with_metadata(metadata);
|
|
let question = test_client
|
|
.api_client
|
|
.create_question(&workspace_id, &chat_id, params)
|
|
.await
|
|
.unwrap();
|
|
let expected = json!({
|
|
"id": "123",
|
|
"name": "test context",
|
|
"source": "user added",
|
|
"extra": {
|
|
"created_at": 123
|
|
}
|
|
});
|
|
assert_json_include!(
|
|
actual: json!(question.meta_data[0]),
|
|
expected: expected
|
|
);
|
|
|
|
let related_questions = test_client
|
|
.api_client
|
|
.get_chat_related_question(&workspace_id, &chat_id, question.message_id)
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(related_questions.items.len(), 3);
|
|
println!("related questions: {:?}", related_questions.items);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn generate_chat_message_answer_test() {
|
|
if !ai_test_enabled() {
|
|
return;
|
|
}
|
|
let test_client = TestClient::new_user_without_ws_conn().await;
|
|
let workspace_id = test_client.workspace_id().await;
|
|
let chat_id = uuid::Uuid::new_v4().to_string();
|
|
let params = CreateChatParams {
|
|
chat_id: chat_id.clone(),
|
|
name: "my second chat".to_string(),
|
|
rag_ids: vec![],
|
|
};
|
|
|
|
test_client
|
|
.api_client
|
|
.create_chat(&workspace_id, params)
|
|
.await
|
|
.unwrap();
|
|
let params = CreateChatMessageParams::new_user("Hello");
|
|
let question = test_client
|
|
.api_client
|
|
.create_question(&workspace_id, &chat_id, params)
|
|
.await
|
|
.unwrap();
|
|
let answer_stream = test_client
|
|
.api_client
|
|
.stream_answer_v2(&workspace_id, &chat_id, question.message_id)
|
|
.await
|
|
.unwrap();
|
|
let answer = collect_answer(answer_stream).await;
|
|
assert!(!answer.is_empty());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn get_format_question_message_test() {
|
|
if !ai_test_enabled() {
|
|
return;
|
|
}
|
|
|
|
let test_client = TestClient::new_user_without_ws_conn().await;
|
|
let workspace_id = test_client.workspace_id().await;
|
|
let chat_id = uuid::Uuid::new_v4().to_string();
|
|
let params = CreateChatParams {
|
|
chat_id: chat_id.clone(),
|
|
name: "my ai chat".to_string(),
|
|
rag_ids: vec![],
|
|
};
|
|
|
|
test_client
|
|
.api_client
|
|
.create_chat(&workspace_id, params)
|
|
.await
|
|
.unwrap();
|
|
|
|
let params = CreateChatMessageParams::new_user(
|
|
"what is the different between Rust and c++? Give me three points",
|
|
);
|
|
let question = test_client
|
|
.api_client
|
|
.create_question(&workspace_id, &chat_id, params)
|
|
.await
|
|
.unwrap();
|
|
|
|
let query = ChatQuestionQuery {
|
|
chat_id,
|
|
question_id: question.message_id,
|
|
format: ResponseFormat {
|
|
output_layout: OutputLayout::SimpleTable,
|
|
output_content: OutputContent::TEXT,
|
|
output_content_metadata: None,
|
|
},
|
|
};
|
|
|
|
let answer_stream = test_client
|
|
.api_client
|
|
.stream_answer_v3(&workspace_id, query)
|
|
.await
|
|
.unwrap();
|
|
let answer = collect_answer(answer_stream).await;
|
|
println!("answer:\n{}", answer);
|
|
assert!(!answer.is_empty());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn get_text_with_image_message_test() {
|
|
if !ai_test_enabled() {
|
|
return;
|
|
}
|
|
|
|
let test_client = TestClient::new_user_without_ws_conn().await;
|
|
let workspace_id = test_client.workspace_id().await;
|
|
let chat_id = uuid::Uuid::new_v4().to_string();
|
|
let params = CreateChatParams {
|
|
chat_id: chat_id.clone(),
|
|
name: "my ai chat".to_string(),
|
|
rag_ids: vec![],
|
|
};
|
|
|
|
test_client
|
|
.api_client
|
|
.create_chat(&workspace_id, params)
|
|
.await
|
|
.unwrap();
|
|
|
|
let params = CreateChatMessageParams::new_user(
|
|
"I have a little cat. It is black with big eyes, short legs and a long tail",
|
|
);
|
|
let question = test_client
|
|
.api_client
|
|
.create_question(&workspace_id, &chat_id, params)
|
|
.await
|
|
.unwrap();
|
|
|
|
let query = ChatQuestionQuery {
|
|
chat_id,
|
|
question_id: question.message_id,
|
|
format: ResponseFormat {
|
|
output_layout: OutputLayout::SimpleTable,
|
|
output_content: OutputContent::RichTextImage,
|
|
output_content_metadata: Some(OutputContentMetadata {
|
|
custom_image_prompt: None,
|
|
image_model: "dall-e-3".to_string(),
|
|
size: None,
|
|
quality: None,
|
|
}),
|
|
},
|
|
};
|
|
|
|
let answer_stream = test_client
|
|
.api_client
|
|
.stream_answer_v3(&workspace_id, query)
|
|
.await
|
|
.unwrap();
|
|
let answer = collect_answer(answer_stream).await;
|
|
println!("answer:\n{}", answer);
|
|
assert!(!answer.is_empty());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn get_question_message_test() {
|
|
if !ai_test_enabled() {
|
|
return;
|
|
}
|
|
|
|
let test_client = TestClient::new_user_without_ws_conn().await;
|
|
let workspace_id = test_client.workspace_id().await;
|
|
let chat_id = uuid::Uuid::new_v4().to_string();
|
|
let params = CreateChatParams {
|
|
chat_id: chat_id.clone(),
|
|
name: "my ai chat".to_string(),
|
|
rag_ids: vec![],
|
|
};
|
|
|
|
test_client
|
|
.api_client
|
|
.create_chat(&workspace_id, params)
|
|
.await
|
|
.unwrap();
|
|
|
|
let params = CreateChatMessageParams::new_user("where is singapore?");
|
|
let question = test_client
|
|
.api_client
|
|
.create_question(&workspace_id, &chat_id, params)
|
|
.await
|
|
.unwrap();
|
|
|
|
let answer = test_client
|
|
.api_client
|
|
.get_answer(&workspace_id, &chat_id, question.message_id)
|
|
.await
|
|
.unwrap();
|
|
|
|
test_client
|
|
.api_client
|
|
.save_answer(
|
|
&workspace_id,
|
|
&chat_id,
|
|
CreateAnswerMessageParams {
|
|
content: answer.content,
|
|
metadata: None,
|
|
question_message_id: question.message_id,
|
|
},
|
|
)
|
|
.await
|
|
.unwrap();
|
|
|
|
let find_question = test_client
|
|
.api_client
|
|
.get_question_message_from_answer_id(&workspace_id, &chat_id, answer.message_id)
|
|
.await
|
|
.unwrap()
|
|
.unwrap();
|
|
|
|
assert_eq!(find_question.reply_message_id.unwrap(), answer.message_id);
|
|
}
|
|
|
|
async fn collect_answer(mut stream: QuestionStream) -> String {
|
|
let mut answer = String::new();
|
|
while let Some(value) = stream.next().await {
|
|
match value.unwrap() {
|
|
QuestionStreamValue::Answer { value } => {
|
|
answer.push_str(&value);
|
|
},
|
|
QuestionStreamValue::Metadata { .. } => {},
|
|
}
|
|
}
|
|
answer
|
|
}
|