chore: generate answer manually (#594)

* chore: generate answer manually

* chore: rename

* chore: return reply message id

* chore: save message

* chore: commit schema files
This commit is contained in:
Nathan.fooo 2024-06-03 08:06:23 +08:00 committed by GitHub
parent 1cc5b58254
commit b36715dc24
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 674 additions and 71 deletions

View file

@ -0,0 +1,22 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT reply_message_id\n FROM af_chat_messages\n WHERE message_id = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "reply_message_id",
"type_info": "Int8"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
true
]
},
"hash": "4f5951e61713d04963524b84648c9ff8c7be05f0089f6fd26fc6e0e0afeae579"
}

View file

@ -0,0 +1,22 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT content\n FROM af_chat_messages\n WHERE message_id = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "content",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false
]
},
"hash": "5fd78d55ed9c4b866f1ce883c5cc89d4df0d2b5daf485a9957e71112d2682f9a"
}

View file

@ -0,0 +1,30 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO af_chat_messages (chat_id, author, content)\n VALUES ($1, $2, $3)\n RETURNING message_id, created_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "message_id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Uuid",
"Jsonb",
"Text"
]
},
"nullable": [
false,
false
]
},
"hash": "6c0a058f5a2a53ad6f89fc7b9af214d7cd5026093bb3f1c94570c87441294977"
}

View file

@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "\n DELETE FROM af_chat_messages\n WHERE message_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": []
},
"hash": "842243ea6ca59135ae539060ff37b80791e76aa268a44642ede515f315e80c01"
}

View file

@ -0,0 +1,16 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE af_chat_messages\n SET content = $2,\n author = $3,\n created_at = CURRENT_TIMESTAMP\n WHERE message_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8",
"Text",
"Jsonb"
]
},
"nullable": []
},
"hash": "8b4af677dd62367e68cdbf9a462e183c702cf98113cacf56a6acb6d130d2a79a"
}

View file

@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE af_chat_messages\n SET reply_message_id = $2\n WHERE message_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8",
"Int8"
]
},
"nullable": []
},
"hash": "bbb3c31ea7e9c0a3bdabbc23b2730ee0254f38a7c1457f917c8f37f1e1aefa12"
}

View file

@ -0,0 +1,52 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT message_id, content, created_at, author, meta_data, reply_message_id\n FROM af_chat_messages\n WHERE message_id = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "message_id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "content",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 3,
"name": "author",
"type_info": "Jsonb"
},
{
"ordinal": 4,
"name": "meta_data",
"type_info": "Jsonb"
},
{
"ordinal": 5,
"name": "reply_message_id",
"type_info": "Int8"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false,
false,
false,
false,
false,
true
]
},
"hash": "d1ab621e0b6e8bc24f8fa8cbb975ae3b7f9f366cac02d66b5291d7207295ca29"
}

View file

@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT message_id, content, created_at, author, meta_data\n FROM af_chat_messages\n WHERE chat_id = $1\n ORDER BY created_at ASC\n ",
"query": "\n SELECT message_id, content, created_at, author, meta_data, reply_message_id\n FROM af_chat_messages\n WHERE chat_id = $1\n ORDER BY created_at ASC\n ",
"describe": {
"columns": [
{
@ -27,6 +27,11 @@
"ordinal": 4,
"name": "meta_data",
"type_info": "Jsonb"
},
{
"ordinal": 5,
"name": "reply_message_id",
"type_info": "Int8"
}
],
"parameters": {
@ -39,8 +44,9 @@
false,
false,
false,
false
false,
true
]
},
"hash": "69b0bbc97c37de47ad3f9cfe4f0099496470100d49903e34c8f773c43f094b2f"
"hash": "d84ab58e78653688e7c392ffad00d6e039be5ccb9c5b99b7088cc41cfe981873"
}

View file

@ -152,7 +152,7 @@ impl AppFlowyAIClient {
content: content.to_string(),
},
};
let url = format!("{}/chat/stream_message", self.url);
let url = format!("{}/chat/message/stream", self.url);
let resp = self
.http_client(Method::POST, &url)?
.json(&json)

View file

@ -21,7 +21,7 @@ async fn qa_test() {
assert_eq!(questions.len(), 3)
}
#[tokio::test]
async fn stop_steam_test() {
async fn stop_stream_test() {
let client = appflowy_ai_client();
client.health_check().await.unwrap();
let chat_id = uuid::Uuid::new_v4().to_string();
@ -43,7 +43,7 @@ async fn stop_steam_test() {
}
#[tokio::test]
async fn steam_test() {
async fn stream_test() {
let client = appflowy_ai_client();
client.health_check().await.unwrap();
let chat_id = uuid::Uuid::new_v4().to_string();

View file

@ -2,6 +2,7 @@ use crate::http::log_request_id;
use crate::Client;
use database_entity::dto::{
ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor, RepeatedChatMessage,
UpdateChatMessageContentParams,
};
use futures_core::Stream;
use reqwest::Method;
@ -58,6 +59,48 @@ impl Client {
log_request_id(&resp);
AppResponse::<ChatMessage>::stream_response(resp).await
}
pub async fn update_chat_message(
&self,
workspace_id: &str,
chat_id: &str,
params: UpdateChatMessageContentParams,
) -> Result<(), AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/message",
self.base_url
);
let resp = self
.http_client_with_auth(Method::PUT, &url)
.await?
.json(&params)
.send()
.await?;
log_request_id(&resp);
AppResponse::<()>::from_response(resp).await?.into_error()
}
pub async fn generate_question_answer(
&self,
workspace_id: &str,
chat_id: &str,
message_id: i64,
) -> Result<ChatMessage, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/{message_id}/answer",
self.base_url
);
let resp = self
.http_client_with_auth(Method::GET, &url)
.await?
.send()
.await?;
log_request_id(&resp);
AppResponse::<ChatMessage>::from_response(resp)
.await?
.into_data()
}
pub async fn get_chat_related_question(
&self,
workspace_id: &str,

View file

@ -573,11 +573,18 @@ pub struct CreateChatMessageParams {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateChatMessageParams {
pub struct UpdateChatMessageMetaParams {
pub message_id: i64,
pub meta_data: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateChatMessageContentParams {
pub chat_id: String,
pub message_id: i64,
pub content: String,
}
#[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)]
#[repr(u8)]
pub enum ChatMessageType {
@ -651,6 +658,7 @@ pub struct ChatMessage {
pub content: String,
pub created_at: DateTime<Utc>,
pub meta_data: serde_json::Value,
pub reply_message_id: Option<i64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -723,3 +731,8 @@ pub enum EmbeddingContentType {
/// The plain text representation of the document.
PlainText = 0,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateChatMessageResponse {
pub answer: Option<ChatMessage>,
}

View file

@ -5,7 +5,8 @@ use app_error::AppError;
use chrono::{DateTime, Utc};
use database_entity::dto::{
ChatAuthor, ChatMessage, CreateChatParams, GetChatMessageParams, MessageCursor,
RepeatedChatMessage, UpdateChatMessageParams, UpdateChatParams,
RepeatedChatMessage, UpdateChatMessageContentParams, UpdateChatMessageMetaParams,
UpdateChatParams,
};
use serde_json::json;
@ -129,7 +130,129 @@ pub async fn select_chat<'a, E: Executor<'a, Database = Postgres>>(
}
}
pub async fn insert_chat_message<'a, E: Executor<'a, Database = Postgres>>(
pub async fn insert_answer_message_with_transaction(
transaction: &mut Transaction<'_, Postgres>,
author: ChatAuthor,
chat_id: &str,
content: String,
question_message_id: i64,
) -> Result<ChatMessage, AppError> {
let chat_id = Uuid::from_str(chat_id)?;
let existing_reply_message_id: Option<i64> = sqlx::query_scalar!(
r#"
SELECT reply_message_id
FROM af_chat_messages
WHERE message_id = $1
"#,
question_message_id
)
.fetch_one(transaction.deref_mut())
.await?;
if let Some(reply_id) = existing_reply_message_id {
// If there is an existing reply_message_id, update the existing message
sqlx::query!(
r#"
UPDATE af_chat_messages
SET content = $2,
author = $3,
created_at = CURRENT_TIMESTAMP
WHERE message_id = $1
"#,
reply_id,
&content,
json!(author),
)
.execute(transaction.deref_mut())
.await
.map_err(|err| AppError::Internal(anyhow!("Failed to update chat message: {}", err)))?;
let row = sqlx::query!(
r#"
SELECT message_id, content, created_at, author, meta_data, reply_message_id
FROM af_chat_messages
WHERE message_id = $1
"#,
reply_id
)
.fetch_one(transaction.deref_mut())
.await
.map_err(|err| AppError::Internal(anyhow!("Failed to fetch updated message: {}", err)))?;
let chat_message = ChatMessage {
author,
message_id: row.message_id,
content: row.content,
created_at: row.created_at,
meta_data: row.meta_data,
reply_message_id: Some(question_message_id),
};
Ok(chat_message)
} else {
// Insert a new chat message
let row = sqlx::query!(
r#"
INSERT INTO af_chat_messages (chat_id, author, content)
VALUES ($1, $2, $3)
RETURNING message_id, created_at
"#,
chat_id,
json!(author),
&content,
)
.fetch_one(transaction.deref_mut())
.await
.map_err(|err| AppError::Internal(anyhow!("Failed to insert chat message: {}", err)))?;
// Update the question message with the new reply_message_id
sqlx::query!(
r#"
UPDATE af_chat_messages
SET reply_message_id = $2
WHERE message_id = $1
"#,
question_message_id,
row.message_id,
)
.execute(transaction.deref_mut())
.await
.map_err(|err| AppError::Internal(anyhow!("Failed to update reply_message_id: {}", err)))?;
let chat_message = ChatMessage {
author,
message_id: row.message_id,
content,
created_at: row.created_at,
meta_data: Default::default(),
reply_message_id: None,
};
Ok(chat_message)
}
}
pub async fn insert_answer_message(
pg_pool: &PgPool,
author: ChatAuthor,
chat_id: &str,
content: String,
question_message_id: i64,
) -> Result<ChatMessage, AppError> {
let mut txn = pg_pool.begin().await?;
let chat_message =
insert_answer_message_with_transaction(&mut txn, author, chat_id, content, question_message_id)
.await?;
txn.commit().await.map_err(|err| {
AppError::Internal(anyhow!(
"Failed to commit transaction to insert answer message: {}",
err
))
})?;
Ok(chat_message)
}
pub async fn insert_question_message<'a, E: Executor<'a, Database = Postgres>>(
executor: E,
author: ChatAuthor,
chat_id: &str,
@ -156,6 +279,7 @@ pub async fn insert_chat_message<'a, E: Executor<'a, Database = Postgres>>(
content,
created_at: row.created_at,
meta_data: Default::default(),
reply_message_id: None,
};
Ok(chat_message)
}
@ -167,7 +291,7 @@ pub async fn select_chat_messages(
) -> Result<RepeatedChatMessage, AppError> {
let chat_id = Uuid::from_str(chat_id)?;
let mut query = r#"
SELECT message_id, content, created_at, author, meta_data
SELECT message_id, content, created_at, author, meta_data, reply_message_id
FROM af_chat_messages
WHERE chat_id = $1
"#
@ -204,33 +328,38 @@ pub async fn select_chat_messages(
},
}
#[allow(clippy::type_complexity)]
let rows: Vec<(
i64,
String,
DateTime<Utc>,
serde_json::Value,
serde_json::Value,
Option<i64>,
)> = sqlx::query_as_with(&query, args)
.fetch_all(txn.deref_mut())
.await?;
let messages = rows
.into_iter()
.flat_map(|(message_id, content, created_at, author, meta_data)| {
match serde_json::from_value::<ChatAuthor>(author) {
Ok(author) => Some(ChatMessage {
author,
message_id,
content,
created_at,
meta_data,
}),
Err(err) => {
warn!("Failed to deserialize author: {}", err);
None
},
}
})
.flat_map(
|(message_id, content, created_at, author, meta_data, reply_message_id)| {
match serde_json::from_value::<ChatAuthor>(author) {
Ok(author) => Some(ChatMessage {
author,
message_id,
content,
created_at,
meta_data,
reply_message_id,
}),
Err(err) => {
warn!("Failed to deserialize author: {}", err);
None
},
}
},
)
.collect::<Vec<ChatMessage>>();
let total = sqlx::query_scalar!(
@ -295,7 +424,7 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
let rows = sqlx::query!(
// ChatMessage,
r#"
SELECT message_id, content, created_at, author, meta_data
SELECT message_id, content, created_at, author, meta_data, reply_message_id
FROM af_chat_messages
WHERE chat_id = $1
ORDER BY created_at ASC
@ -315,6 +444,7 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
content: row.content,
created_at: row.created_at,
meta_data: row.meta_data,
reply_message_id: row.reply_message_id,
}),
Err(err) => {
warn!("Failed to deserialize author: {}", err);
@ -327,29 +457,98 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
Ok(messages)
}
pub async fn update_chat_message(
pg_pool: &PgPool,
params: UpdateChatMessageParams,
pub async fn delete_answer_message_by_question_message_id(
transaction: &mut Transaction<'_, Postgres>,
message_id: i64,
) -> Result<(), AppError> {
for (key, value) in params.meta_data.iter() {
sqlx::query(
// Step 1: Get the reply_message_id of the chat message with the given message_id
let reply_message_id: Option<i64> = sqlx::query_scalar!(
r#"
SELECT reply_message_id
FROM af_chat_messages
WHERE message_id = $1
"#,
message_id
)
.fetch_one(transaction.deref_mut())
.await?;
if let Some(reply_id) = reply_message_id {
// Step 2: Delete the chat message with the reply_message_id
sqlx::query!(
r#"
UPDATE af_chat_messages
SET meta_data = jsonb_set(
COALESCE(meta_data, '{}'),
$2,
$3::jsonb,
true
)
WHERE id = $1
"#,
DELETE FROM af_chat_messages
WHERE message_id = $1
"#,
reply_id
)
.bind(params.message_id)
.bind(format!("{{{}}}", key))
.bind(value)
.execute(pg_pool)
.execute(transaction.deref_mut())
.await?;
}
Ok(())
}
pub async fn update_chat_message_content(
transaction: &mut Transaction<'_, Postgres>,
params: &UpdateChatMessageContentParams,
) -> Result<(), AppError> {
sqlx::query(
r#"
UPDATE af_chat_messages
SET content = $2,
edited_at = CURRENT_TIMESTAMP
WHERE message_id = $1
"#,
)
.bind(params.message_id)
.bind(&params.content)
.execute(transaction.deref_mut())
.await?;
Ok(())
}
pub async fn update_chat_message_meta(
transaction: &mut Transaction<'_, Postgres>,
params: &UpdateChatMessageMetaParams,
) -> Result<(), AppError> {
for (key, value) in params.meta_data.iter() {
sqlx::query(
r#"
UPDATE af_chat_messages
SET meta_data = jsonb_set(
COALESCE(meta_data, '{}'),
$2,
$3::jsonb,
true
)
WHERE message_id = $1
"#,
)
.bind(params.message_id)
.bind(format!("{{{}}}", key))
.bind(value)
.execute(transaction.deref_mut())
.await?;
}
Ok(())
}
pub async fn select_chat_message_content<'a, E: Executor<'a, Database = Postgres>>(
executor: E,
message_id: i64,
) -> Result<String, AppError> {
let row = sqlx::query!(
r#"
SELECT content
FROM af_chat_messages
WHERE message_id = $1
"#,
message_id,
)
.fetch_one(executor)
.await?;
Ok(row.content)
}

View file

@ -3,4 +3,5 @@ ALTER TABLE af_chat
ADD COLUMN meta_data JSONB DEFAULT '{}' NOT NULL;
ALTER TABLE af_chat_messages
ADD COLUMN meta_data JSONB DEFAULT '{}' NOT NULL;
ADD COLUMN meta_data JSONB DEFAULT '{}' NOT NULL,
ADD COLUMN reply_message_id BIGINT;

View file

@ -1,14 +1,16 @@
use crate::biz::chat::ops::{create_chat, create_chat_message, delete_chat, get_chat_messages};
use crate::biz::chat::ops::{
create_chat, create_chat_message, delete_chat, generate_chat_message_answer, get_chat_messages,
update_chat_message,
};
use crate::state::AppState;
use actix_web::web::{Data, Json};
use actix_web::{web, HttpResponse, Scope};
use app_error::AppError;
use appflowy_ai_client::dto::RepeatedRelatedQuestion;
use authentication::jwt::UserUuid;
use database::chat::chat_ops::update_chat_message;
use database_entity::dto::{
CreateChatMessageParams, CreateChatParams, GetChatMessageParams, MessageCursor,
RepeatedChatMessage, UpdateChatMessageParams,
ChatMessage, CreateChatMessageParams, CreateChatParams, GetChatMessageParams, MessageCursor,
RepeatedChatMessage, UpdateChatMessageContentParams,
};
use shared_entity::response::{AppResponse, JsonAppResponse};
use std::collections::HashMap;
@ -22,16 +24,18 @@ pub fn chat_scope() -> Scope {
.service(
web::resource("/{chat_id}")
.route(web::delete().to(delete_chat_handler))
.route(web::post().to(update_chat_handler))
.route(web::get().to(get_chat_message_handler)),
)
.service(
web::resource("/{chat_id}/{message_id}/related_question")
.route(web::get().to(get_related_message_handler)),
)
.service(
web::resource("/{chat_id}/{message_id}/answer").route(web::get().to(generate_answer_handler)),
)
.service(
web::resource("/{chat_id}/message")
.route(web::post().to(post_chat_message_handler))
.route(web::post().to(create_chat_message_handler))
.route(web::put().to(update_chat_message_handler)),
)
}
@ -56,16 +60,7 @@ async fn delete_chat_handler(
Ok(AppResponse::Ok().into())
}
async fn update_chat_handler(
path: web::Path<(String, String)>,
state: Data<AppState>,
) -> actix_web::Result<JsonAppResponse<()>> {
let (_workspace_id, chat_id) = path.into_inner();
delete_chat(&state.pg_pool, &chat_id).await?;
Ok(AppResponse::Ok().into())
}
async fn post_chat_message_handler(
async fn create_chat_message_handler(
state: Data<AppState>,
path: web::Path<(String, String)>,
payload: Json<CreateChatMessageParams>,
@ -96,10 +91,10 @@ async fn post_chat_message_handler(
async fn update_chat_message_handler(
state: Data<AppState>,
payload: Json<UpdateChatMessageParams>,
payload: Json<UpdateChatMessageContentParams>,
) -> actix_web::Result<JsonAppResponse<()>> {
let params = payload.into_inner();
update_chat_message(&state.pg_pool, params).await?;
update_chat_message(&state.pg_pool, params, state.ai_client.clone()).await?;
Ok(AppResponse::Ok().into())
}
@ -116,6 +111,21 @@ async fn get_related_message_handler(
Ok(AppResponse::Ok().with_data(resp).into())
}
async fn generate_answer_handler(
path: web::Path<(String, String, i64)>,
state: Data<AppState>,
) -> actix_web::Result<JsonAppResponse<ChatMessage>> {
let (_workspace_id, chat_id, message_id) = path.into_inner();
let message = generate_chat_message_answer(
&state.pg_pool,
state.ai_client.clone(),
message_id,
&chat_id,
)
.await?;
Ok(AppResponse::Ok().with_data(message).into())
}
#[instrument(level = "debug", skip_all, err)]
async fn get_chat_message_handler(
path: web::Path<(String, String)>,

View file

@ -1,13 +1,18 @@
use actix_web::web::Bytes;
use anyhow::anyhow;
use app_error::AppError;
use appflowy_ai_client::client::AppFlowyAIClient;
use async_stream::stream;
use database::chat;
use database::chat::chat_ops::{insert_chat, insert_chat_message, select_chat_messages};
use database::chat::chat_ops::{
delete_answer_message_by_question_message_id, insert_answer_message,
insert_answer_message_with_transaction, insert_chat, insert_question_message,
select_chat_messages,
};
use database_entity::dto::{
ChatAuthor, ChatAuthorType, ChatMessageType, CreateChatMessageParams, CreateChatParams,
GetChatMessageParams, RepeatedChatMessage,
ChatAuthor, ChatAuthorType, ChatMessage, ChatMessageType, CreateChatMessageParams,
CreateChatParams, GetChatMessageParams, RepeatedChatMessage, UpdateChatMessageContentParams,
};
use futures::stream::Stream;
use sqlx::PgPool;
@ -34,6 +39,65 @@ pub(crate) async fn delete_chat(pg_pool: &PgPool, chat_id: &str) -> Result<(), A
Ok(())
}
pub async fn update_chat_message(
pg_pool: &PgPool,
params: UpdateChatMessageContentParams,
ai_client: AppFlowyAIClient,
) -> Result<(), AppError> {
let mut txn = pg_pool.begin().await?;
delete_answer_message_by_question_message_id(&mut txn, params.message_id).await?;
chat::chat_ops::update_chat_message_content(&mut txn, &params).await?;
txn.commit().await.map_err(|err| {
AppError::Internal(anyhow!(
"Failed to commit transaction to update chat message: {}",
err
))
})?;
let new_answer = ai_client
.send_question(&params.chat_id, &params.content)
.await?;
let _answer = insert_answer_message(
pg_pool,
ChatAuthor::ai(),
&params.chat_id,
new_answer.content,
params.message_id,
)
.await?;
Ok(())
}
pub async fn generate_chat_message_answer(
pg_pool: &PgPool,
ai_client: AppFlowyAIClient,
question_message_id: i64,
chat_id: &str,
) -> Result<ChatMessage, AppError> {
let content = chat::chat_ops::select_chat_message_content(pg_pool, question_message_id).await?;
let new_answer = ai_client.send_question(chat_id, &content).await?;
// Save the answer to the database
let mut txn = pg_pool.begin().await?;
let message = insert_answer_message_with_transaction(
&mut txn,
ChatAuthor::ai(),
chat_id,
new_answer.content,
question_message_id,
)
.await?;
txn.commit().await.map_err(|err| {
AppError::Internal(anyhow!(
"Failed to commit transaction to update chat message: {}",
err
))
})?;
Ok(message)
}
pub async fn create_chat_message(
pg_pool: &PgPool,
uid: i64,
@ -46,7 +110,7 @@ pub async fn create_chat_message(
let pg_pool = pg_pool.clone();
let stream = stream! {
// Insert question message
let question = match insert_chat_message(
let question = match insert_question_message(
&pg_pool,
ChatAuthor::new(uid, ChatAuthorType::Human),
&chat_id,
@ -59,6 +123,7 @@ pub async fn create_chat_message(
}
};
let question_id = question.message_id;
let question_bytes = match serde_json::to_vec(&question) {
Ok(bytes) => bytes,
Err(err) => {
@ -81,7 +146,7 @@ pub async fn create_chat_message(
}
};
let answer = match insert_chat_message(&pg_pool, ChatAuthor::ai(), &chat_id, content.clone()).await {
let answer = match insert_answer_message(&pg_pool, ChatAuthor::ai(), &chat_id, content.clone(),question_id).await {
Ok(answer) => answer,
Err(err) => {
yield Err(err);

View file

@ -1,5 +1,8 @@
use client_api_test::TestClient;
use database_entity::dto::{ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor};
use database_entity::dto::{
ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor,
UpdateChatMessageContentParams,
};
use futures_util::StreamExt;
#[tokio::test]
@ -126,3 +129,95 @@ async fn chat_qa_test() {
assert_eq!(related_questions.items.len(), 3);
println!("related questions: {:?}", related_questions.items);
}
#[tokio::test]
async fn generate_chat_message_answer_test() {
let test_client = TestClient::new_user_without_ws_conn().await;
let workspace_id = test_client.workspace_id().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let params = CreateChatParams {
chat_id: chat_id.clone(),
name: "my second chat".to_string(),
rag_ids: vec![],
};
test_client
.api_client
.create_chat(&workspace_id, params)
.await
.unwrap();
let params = CreateChatMessageParams::new_user("where is singapore?");
let stream = test_client
.api_client
.create_chat_message(&workspace_id, &chat_id, params)
.await
.unwrap();
let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
assert_eq!(messages.len(), 2);
let answer = test_client
.api_client
.generate_question_answer(&workspace_id, &chat_id, messages[0].message_id)
.await
.unwrap();
let remote_messages = test_client
.api_client
.get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 2)
.await
.unwrap()
.messages;
assert_eq!(remote_messages.len(), 2);
assert_eq!(remote_messages[1].message_id, answer.message_id);
}
#[tokio::test]
async fn update_chat_message_test() {
let test_client = TestClient::new_user_without_ws_conn().await;
let workspace_id = test_client.workspace_id().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let params = CreateChatParams {
chat_id: chat_id.clone(),
name: "my second chat".to_string(),
rag_ids: vec![],
};
test_client
.api_client
.create_chat(&workspace_id, params)
.await
.unwrap();
let params = CreateChatMessageParams::new_user("where is singapore?");
let stream = test_client
.api_client
.create_chat_message(&workspace_id, &chat_id, params)
.await
.unwrap();
let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
assert_eq!(messages.len(), 2);
let params = UpdateChatMessageContentParams {
chat_id: chat_id.clone(),
message_id: messages[0].message_id,
content: "where is China?".to_string(),
};
test_client
.api_client
.update_chat_message(&workspace_id, &chat_id, params)
.await
.unwrap();
let remote_messages = test_client
.api_client
.get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 2)
.await
.unwrap()
.messages;
assert_eq!(remote_messages[0].content, "where is China?");
assert_eq!(remote_messages.len(), 2);
// when the question was updated, the answer should be different
assert_ne!(remote_messages[1].content, messages[1].content);
}

View file

@ -1,6 +1,6 @@
use crate::sql_test::util::{setup_db, test_create_user};
use database::chat::chat_ops::{
delete_chat, get_all_chat_messages, insert_chat, insert_chat_message, select_chat,
delete_chat, get_all_chat_messages, insert_chat, insert_question_message, select_chat,
select_chat_messages,
};
use database_entity::dto::{ChatAuthor, ChatAuthorType, CreateChatParams, GetChatMessageParams};
@ -91,7 +91,7 @@ async fn chat_message_crud_test(pool: PgPool) {
// create chat messages
for i in 0..5 {
let _ = insert_chat_message(
let _ = insert_question_message(
&pool,
ChatAuthor::new(0, ChatAuthorType::System),
&chat_id,