chore: extend chat message with meta (#592)

This commit is contained in:
Nathan.fooo 2024-06-01 19:32:39 +08:00 committed by GitHub
parent 89030c420f
commit edfcb5c1ea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 121 additions and 10 deletions

View file

@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT message_id, content, created_at, author\n FROM af_chat_messages\n WHERE chat_id = $1\n ORDER BY created_at ASC\n ",
"query": "\n SELECT message_id, content, created_at, author, meta_data\n FROM af_chat_messages\n WHERE chat_id = $1\n ORDER BY created_at ASC\n ",
"describe": {
"columns": [
{
@ -22,6 +22,11 @@
"ordinal": 3,
"name": "author",
"type_info": "Jsonb"
},
{
"ordinal": 4,
"name": "meta_data",
"type_info": "Jsonb"
}
],
"parameters": {
@ -33,8 +38,9 @@
false,
false,
false,
false,
false
]
},
"hash": "533ef0f5237ca12ce0c6ca1dc938cc8dd34603b256f36dc683013264018332fc"
"hash": "69b0bbc97c37de47ad3f9cfe4f0099496470100d49903e34c8f773c43f094b2f"
}

View file

@ -32,6 +32,11 @@
"ordinal": 5,
"name": "workspace_id",
"type_info": "Uuid"
},
{
"ordinal": 6,
"name": "meta_data",
"type_info": "Jsonb"
}
],
"parameters": {
@ -45,6 +50,7 @@
true,
false,
false,
false,
false
]
},

2
Cargo.lock generated
View file

@ -568,6 +568,8 @@ name = "appflowy-ai-client"
version = "0.1.0"
dependencies = [
"anyhow",
"bytes",
"futures",
"reqwest 0.12.4",
"serde",
"serde_json",

View file

@ -6,13 +6,15 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
reqwest = { version = "0.12", features = ["json", "rustls-tls", "cookies"], optional = true }
reqwest = { version = "0.12", features = ["json", "rustls-tls", "cookies", "stream"], optional = true }
serde = { version = "1.0.199", features = ["derive"], optional = true }
serde_json = { version = "1.0", optional = true }
thiserror = "1.0.58"
anyhow = "1.0.81"
tracing = { version = "0.1", optional = true }
serde_repr = { version = "0.1", optional = true }
futures = "0.3.30"
bytes = "1.6.0"
[dev-dependencies]
tokio = { version = "1.37.0", features = ["macros", "test-util"] }

View file

@ -3,10 +3,14 @@ use crate::dto::{
SearchDocumentsRequest, SummarizeRowResponse, TranslateRowResponse,
};
use crate::error::AIError;
use anyhow::anyhow;
use futures::{Stream, StreamExt};
use reqwest;
use reqwest::{Method, RequestBuilder, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Map, Value};
use std::borrow::Cow;
use tracing::{info, trace};
#[derive(Clone, Debug)]
@ -137,6 +141,26 @@ impl AppFlowyAIClient {
.into_data()
}
pub async fn stream_question(
&self,
chat_id: &str,
content: &str,
) -> Result<impl Stream<Item = Result<String, AIError>>, AIError> {
let json = ChatQuestion {
chat_id: chat_id.to_string(),
data: MessageData {
content: content.to_string(),
},
};
let url = format!("{}/chat/stream_message", self.url);
let resp = self
.http_client(Method::POST, &url)?
.json(&json)
.send()
.await?;
AIResponse::<String>::stream_response(resp).await
}
fn http_client(&self, method: Method, url: &str) -> Result<RequestBuilder, AIError> {
let request_builder = self.client.request(method, url);
Ok(request_builder)
@ -174,8 +198,27 @@ where
Some(data) => Ok(data),
}
}
}
pub async fn stream_response(
resp: reqwest::Response,
) -> Result<impl Stream<Item = Result<String, AIError>>, AIError> {
let status_code = resp.status();
if !status_code.is_success() {
let body = resp.text().await?;
return Err(AIError::InvalidRequest(body));
}
let stream = resp.bytes_stream().map(|item| {
item
.map_err(|err| AIError::Internal(err.into()))
.and_then(|bytes| {
String::from_utf8(bytes.to_vec())
.map(|s| s.replace('\n', ""))
.map_err(|err| AIError::Internal(anyhow!("Parser AI response error: {:?}", err)))
})
});
Ok(stream)
}
}
impl From<reqwest::Error> for AIError {
fn from(error: reqwest::Error) -> Self {
if error.is_timeout() {

View file

@ -1,4 +1,5 @@
use crate::appflowy_ai_client;
use futures::stream::StreamExt;
#[tokio::test]
async fn qa_test() {
@ -11,3 +12,38 @@ async fn qa_test() {
.unwrap();
assert!(!resp.content.is_empty());
}
#[tokio::test]
async fn stop_steam_test() {
let client = appflowy_ai_client();
client.health_check().await.unwrap();
let chat_id = uuid::Uuid::new_v4().to_string();
let mut stream = client
.stream_question(&chat_id, "I feel hungry")
.await
.unwrap();
let mut count = 0;
while let Some(message) = stream.next().await {
if count > 1 {
break;
}
count += 1;
println!("message: {:?}", message);
}
assert_ne!(count, 0);
}
async fn steam_test() {
let client = appflowy_ai_client();
client.health_check().await.unwrap();
let chat_id = uuid::Uuid::new_v4().to_string();
let mut stream = client
.stream_question(&chat_id, "I feel hungry")
.await
.unwrap();
let messages: Vec<String> = stream.map(|message| message.unwrap()).collect().await;
println!("final answer: {}", messages.join(""));
}

View file

@ -644,6 +644,7 @@ pub struct ChatMessage {
pub message_id: i64,
pub content: String,
pub created_at: DateTime<Utc>,
pub meta_data: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -155,6 +155,7 @@ pub async fn insert_chat_message<'a, E: Executor<'a, Database = Postgres>>(
message_id: row.message_id,
content,
created_at: row.created_at,
meta_data: Default::default(),
};
Ok(chat_message)
}
@ -203,20 +204,26 @@ pub async fn select_chat_messages(
},
}
let rows: Vec<(i64, String, DateTime<Utc>, serde_json::Value)> =
sqlx::query_as_with(&query, args)
.fetch_all(txn.deref_mut())
.await?;
let rows: Vec<(
i64,
String,
DateTime<Utc>,
serde_json::Value,
serde_json::Value,
)> = sqlx::query_as_with(&query, args)
.fetch_all(txn.deref_mut())
.await?;
let messages = rows
.into_iter()
.flat_map(|(message_id, content, created_at, author)| {
.flat_map(|(message_id, content, created_at, author, meta_data)| {
match serde_json::from_value::<ChatAuthor>(author) {
Ok(author) => Some(ChatMessage {
author,
message_id,
content,
created_at,
meta_data,
}),
Err(err) => {
warn!("Failed to deserialize author: {}", err);
@ -288,7 +295,7 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
let rows = sqlx::query!(
// ChatMessage,
r#"
SELECT message_id, content, created_at, author
SELECT message_id, content, created_at, author, meta_data
FROM af_chat_messages
WHERE chat_id = $1
ORDER BY created_at ASC
@ -307,6 +314,7 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
message_id: row.message_id,
content: row.content,
created_at: row.created_at,
meta_data: row.meta_data,
}),
Err(err) => {
warn!("Failed to deserialize author: {}", err);

View file

@ -204,6 +204,7 @@ pub struct AFChatRow {
pub deleted_at: Option<DateTime<Utc>>,
pub rag_ids: serde_json::Value,
pub workspace_id: Uuid,
pub meta_data: serde_json::Value,
}
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
pub struct AFChatMessageRow {

View file

@ -0,0 +1,6 @@
-- Add migration script here
ALTER TABLE af_chat
ADD COLUMN meta_data JSONB DEFAULT '{}' NOT NULL;
ALTER TABLE af_chat_messages
ADD COLUMN meta_data JSONB DEFAULT '{}' NOT NULL;