chore: find question message from reply message (#1085)

* chore: find question message from answer message id

* chore: sqlx

* test: fix tests

* test: fix test

* chore: apply code suggestions to 2 files
This commit is contained in:
Richard Shiue 2024-12-19 00:12:53 +08:00 committed by GitHub
parent e758f18d75
commit ecadf8e287
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 216 additions and 9 deletions

View file

@ -0,0 +1,53 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT message_id, content, created_at, author, meta_data, reply_message_id\n FROM af_chat_messages\n WHERE chat_id = $1\n AND reply_message_id = $2\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "message_id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "content",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 3,
"name": "author",
"type_info": "Jsonb"
},
{
"ordinal": 4,
"name": "meta_data",
"type_info": "Jsonb"
},
{
"ordinal": 5,
"name": "reply_message_id",
"type_info": "Int8"
}
],
"parameters": {
"Left": [
"Uuid",
"Int8"
]
},
"nullable": [
false,
false,
false,
false,
false,
true
]
},
"hash": "794c4ced16801b3e98a62eb44c18c14137dd09b11be73442a7f46b2f938b8445"
}

View file

@ -262,6 +262,28 @@ impl Client {
.into_data()
}
pub async fn get_question_message_from_answer_id(
&self,
workspace_id: &str,
chat_id: &str,
answer_message_id: i64,
) -> Result<Option<ChatMessage>, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/message/find_question",
self.base_url
);
let resp = self
.http_client_with_auth(Method::GET, &url)
.await?
.query(&[("answer_message_id", answer_message_id)])
.send()
.await?;
AppResponse::<Option<ChatMessage>>::from_response(resp)
.await?
.into_data()
}
pub async fn calculate_similarity(
&self,
params: CalculateSimilarityParams,

View file

@ -669,3 +669,40 @@ pub async fn select_chat_message_content<'a, E: Executor<'a, Database = Postgres
.await?;
Ok((row.content, row.meta_data))
}
pub async fn select_chat_message_matching_reply_message_id(
txn: &mut Transaction<'_, Postgres>,
chat_id: &str,
reply_message_id: i64,
) -> Result<Option<ChatMessage>, AppError> {
let chat_id = Uuid::from_str(chat_id)?;
let row = sqlx::query!(
r#"
SELECT message_id, content, created_at, author, meta_data, reply_message_id
FROM af_chat_messages
WHERE chat_id = $1
AND reply_message_id = $2
"#,
&chat_id,
reply_message_id
)
.fetch_one(txn.deref_mut())
.await?;
let message = match serde_json::from_value::<ChatAuthor>(row.author) {
Ok(author) => Some(ChatMessage {
author,
message_id: row.message_id,
content: row.content,
created_at: row.created_at,
meta_data: row.meta_data,
reply_message_id: row.reply_message_id,
}),
Err(err) => {
warn!("Failed to deserialize author: {}", err);
None
},
};
Ok(message)
}

View file

@ -1,10 +1,11 @@
use crate::biz::chat::ops::{
create_chat, create_chat_message, delete_chat, generate_chat_message_answer, get_chat_messages,
update_chat_message,
get_question_message, update_chat_message,
};
use crate::state::AppState;
use actix_web::web::{Data, Json};
use actix_web::{web, HttpRequest, HttpResponse, Scope};
use serde::Deserialize;
use crate::api::util::ai_model_from_header;
use app_error::AppError;
@ -69,6 +70,10 @@ pub fn chat_scope() -> Scope {
web::resource("/{chat_id}/message/answer")
.route(web::post().to(save_answer_handler))
)
.service(
web::resource("/{chat_id}/message/find_question")
.route(web::get().to(get_chat_question_message_handler))
)
// AI response generation
.service(
@ -349,6 +354,17 @@ async fn get_chat_message_handler(
Ok(AppResponse::Ok().with_data(messages).into())
}
#[instrument(level = "debug", skip_all, err)]
async fn get_chat_question_message_handler(
path: web::Path<(String, String)>,
query: web::Query<FindQuestionParams>,
state: Data<AppState>,
) -> actix_web::Result<JsonAppResponse<Option<ChatMessage>>> {
let (_workspace_id, chat_id) = path.into_inner();
let message = get_question_message(&state.pg_pool, &chat_id, query.0.answer_message_id).await?;
Ok(AppResponse::Ok().with_data(message).into())
}
#[instrument(level = "debug", skip_all, err)]
async fn get_chat_settings_handler(
path: web::Path<(String, String)>,
@ -501,3 +517,8 @@ where
}
}
}
#[derive(Debug, Deserialize)]
struct FindQuestionParams {
answer_message_id: i64,
}

View file

@ -8,7 +8,7 @@ use database::chat;
use database::chat::chat_ops::{
delete_answer_message_by_question_message_id, insert_answer_message,
insert_answer_message_with_transaction, insert_chat, insert_question_message,
select_chat_messages,
select_chat_message_matching_reply_message_id, select_chat_messages,
};
use futures::stream::Stream;
use serde_json::json;
@ -232,3 +232,15 @@ pub async fn get_chat_messages(
txn.commit().await?;
Ok(messages)
}
pub async fn get_question_message(
pg_pool: &PgPool,
chat_id: &str,
answer_message_id: i64,
) -> Result<Option<ChatMessage>, AppError> {
let mut txn = pg_pool.begin().await?;
let message =
select_chat_message_matching_reply_message_id(&mut txn, chat_id, answer_message_id).await?;
txn.commit().await?;
Ok(message)
}

View file

@ -6,8 +6,8 @@ use client_api_test::{ai_test_enabled, TestClient};
use futures_util::StreamExt;
use serde_json::json;
use shared_entity::dto::chat_dto::{
ChatMessageMetadata, ChatRAGData, CreateChatMessageParams, CreateChatParams, MessageCursor,
UpdateChatParams,
ChatMessageMetadata, ChatRAGData, CreateAnswerMessageParams, CreateChatMessageParams,
CreateChatParams, MessageCursor, UpdateChatParams,
};
#[tokio::test]
@ -344,6 +344,10 @@ async fn create_chat_context_test() {
// #[tokio::test]
// async fn update_chat_message_test() {
// if !ai_test_enabled() {
// return;
// }
// let test_client = TestClient::new_user_without_ws_conn().await;
// let workspace_id = test_client.workspace_id().await;
// let chat_id = uuid::Uuid::new_v4().to_string();
@ -352,13 +356,13 @@ async fn create_chat_context_test() {
// name: "my second chat".to_string(),
// rag_ids: vec![],
// };
//
// test_client
// .api_client
// .create_chat(&workspace_id, params)
// .await
// .unwrap();
//
// let params = CreateChatMessageParams::new_user("where is singapore?");
// let stream = test_client
// .api_client
@ -367,7 +371,7 @@ async fn create_chat_context_test() {
// .unwrap();
// let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
// assert_eq!(messages.len(), 2);
//
// let params = UpdateChatMessageContentParams {
// chat_id: chat_id.clone(),
// message_id: messages[0].message_id,
@ -378,7 +382,7 @@ async fn create_chat_context_test() {
// .update_chat_message(&workspace_id, &chat_id, params)
// .await
// .unwrap();
//
// let remote_messages = test_client
// .api_client
// .get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 2)
@ -387,11 +391,69 @@ async fn create_chat_context_test() {
// .messages;
// assert_eq!(remote_messages[0].content, "where is China?");
// assert_eq!(remote_messages.len(), 2);
//
// // when the question was updated, the answer should be different
// assert_ne!(remote_messages[1].content, messages[1].content);
// }
#[tokio::test]
async fn get_question_message_test() {
if !ai_test_enabled() {
return;
}
let test_client = TestClient::new_user_without_ws_conn().await;
let workspace_id = test_client.workspace_id().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let params = CreateChatParams {
chat_id: chat_id.clone(),
name: "my ai chat".to_string(),
rag_ids: vec![],
};
test_client
.api_client
.create_chat(&workspace_id, params)
.await
.unwrap();
let params = CreateChatMessageParams::new_user("where is singapore?");
let question = test_client
.api_client
.create_question(&workspace_id, &chat_id, params)
.await
.unwrap();
let answer = test_client
.api_client
.get_answer(&workspace_id, &chat_id, question.message_id)
.await
.unwrap();
test_client
.api_client
.save_answer(
&workspace_id,
&chat_id,
CreateAnswerMessageParams {
content: answer.content,
metadata: None,
question_message_id: question.message_id,
},
)
.await
.unwrap();
let find_question = test_client
.api_client
.get_question_message_from_answer_id(&workspace_id, &chat_id, answer.message_id)
.await
.unwrap()
.unwrap();
assert_eq!(find_question.reply_message_id.unwrap(), answer.message_id);
}
async fn collect_answer(mut stream: QuestionStream) -> String {
let mut answer = String::new();
while let Some(value) = stream.next().await {