mirror of
https://github.com/AppFlowy-IO/AppFlowy-Cloud.git
synced 2025-04-19 03:24:42 -04:00
chore: extend chat message with meta (#592)
This commit is contained in:
parent
89030c420f
commit
edfcb5c1ea
10 changed files with 121 additions and 10 deletions
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT message_id, content, created_at, author\n FROM af_chat_messages\n WHERE chat_id = $1\n ORDER BY created_at ASC\n ",
|
||||
"query": "\n SELECT message_id, content, created_at, author, meta_data\n FROM af_chat_messages\n WHERE chat_id = $1\n ORDER BY created_at ASC\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
|
@ -22,6 +22,11 @@
|
|||
"ordinal": 3,
|
||||
"name": "author",
|
||||
"type_info": "Jsonb"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "meta_data",
|
||||
"type_info": "Jsonb"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
|
@ -33,8 +38,9 @@
|
|||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "533ef0f5237ca12ce0c6ca1dc938cc8dd34603b256f36dc683013264018332fc"
|
||||
"hash": "69b0bbc97c37de47ad3f9cfe4f0099496470100d49903e34c8f773c43f094b2f"
|
||||
}
|
|
@ -32,6 +32,11 @@
|
|||
"ordinal": 5,
|
||||
"name": "workspace_id",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "meta_data",
|
||||
"type_info": "Jsonb"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
|
@ -45,6 +50,7 @@
|
|||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
|
|
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -568,6 +568,8 @@ name = "appflowy-ai-client"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
"futures",
|
||||
"reqwest 0.12.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
|
@ -6,13 +6,15 @@ edition = "2021"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
reqwest = { version = "0.12", features = ["json", "rustls-tls", "cookies"], optional = true }
|
||||
reqwest = { version = "0.12", features = ["json", "rustls-tls", "cookies", "stream"], optional = true }
|
||||
serde = { version = "1.0.199", features = ["derive"], optional = true }
|
||||
serde_json = { version = "1.0", optional = true }
|
||||
thiserror = "1.0.58"
|
||||
anyhow = "1.0.81"
|
||||
tracing = { version = "0.1", optional = true }
|
||||
serde_repr = { version = "0.1", optional = true }
|
||||
futures = "0.3.30"
|
||||
bytes = "1.6.0"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1.37.0", features = ["macros", "test-util"] }
|
||||
|
|
|
@ -3,10 +3,14 @@ use crate::dto::{
|
|||
SearchDocumentsRequest, SummarizeRowResponse, TranslateRowResponse,
|
||||
};
|
||||
use crate::error::AIError;
|
||||
use anyhow::anyhow;
|
||||
use futures::{Stream, StreamExt};
|
||||
use reqwest;
|
||||
use reqwest::{Method, RequestBuilder, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::borrow::Cow;
|
||||
|
||||
use tracing::{info, trace};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -137,6 +141,26 @@ impl AppFlowyAIClient {
|
|||
.into_data()
|
||||
}
|
||||
|
||||
pub async fn stream_question(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
) -> Result<impl Stream<Item = Result<String, AIError>>, AIError> {
|
||||
let json = ChatQuestion {
|
||||
chat_id: chat_id.to_string(),
|
||||
data: MessageData {
|
||||
content: content.to_string(),
|
||||
},
|
||||
};
|
||||
let url = format!("{}/chat/stream_message", self.url);
|
||||
let resp = self
|
||||
.http_client(Method::POST, &url)?
|
||||
.json(&json)
|
||||
.send()
|
||||
.await?;
|
||||
AIResponse::<String>::stream_response(resp).await
|
||||
}
|
||||
|
||||
fn http_client(&self, method: Method, url: &str) -> Result<RequestBuilder, AIError> {
|
||||
let request_builder = self.client.request(method, url);
|
||||
Ok(request_builder)
|
||||
|
@ -174,8 +198,27 @@ where
|
|||
Some(data) => Ok(data),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_response(
|
||||
resp: reqwest::Response,
|
||||
) -> Result<impl Stream<Item = Result<String, AIError>>, AIError> {
|
||||
let status_code = resp.status();
|
||||
if !status_code.is_success() {
|
||||
let body = resp.text().await?;
|
||||
return Err(AIError::InvalidRequest(body));
|
||||
}
|
||||
let stream = resp.bytes_stream().map(|item| {
|
||||
item
|
||||
.map_err(|err| AIError::Internal(err.into()))
|
||||
.and_then(|bytes| {
|
||||
String::from_utf8(bytes.to_vec())
|
||||
.map(|s| s.replace('\n', ""))
|
||||
.map_err(|err| AIError::Internal(anyhow!("Parser AI response error: {:?}", err)))
|
||||
})
|
||||
});
|
||||
Ok(stream)
|
||||
}
|
||||
}
|
||||
impl From<reqwest::Error> for AIError {
|
||||
fn from(error: reqwest::Error) -> Self {
|
||||
if error.is_timeout() {
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::appflowy_ai_client;
|
||||
use futures::stream::StreamExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn qa_test() {
|
||||
|
@ -11,3 +12,38 @@ async fn qa_test() {
|
|||
.unwrap();
|
||||
assert!(!resp.content.is_empty());
|
||||
}
|
||||
#[tokio::test]
|
||||
|
||||
async fn stop_steam_test() {
|
||||
let client = appflowy_ai_client();
|
||||
client.health_check().await.unwrap();
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let mut stream = client
|
||||
.stream_question(&chat_id, "I feel hungry")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut count = 0;
|
||||
while let Some(message) = stream.next().await {
|
||||
if count > 1 {
|
||||
break;
|
||||
}
|
||||
count += 1;
|
||||
println!("message: {:?}", message);
|
||||
}
|
||||
|
||||
assert_ne!(count, 0);
|
||||
}
|
||||
|
||||
async fn steam_test() {
|
||||
let client = appflowy_ai_client();
|
||||
client.health_check().await.unwrap();
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let mut stream = client
|
||||
.stream_question(&chat_id, "I feel hungry")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let messages: Vec<String> = stream.map(|message| message.unwrap()).collect().await;
|
||||
println!("final answer: {}", messages.join(""));
|
||||
}
|
||||
|
|
|
@ -644,6 +644,7 @@ pub struct ChatMessage {
|
|||
pub message_id: i64,
|
||||
pub content: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub meta_data: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
|
@ -155,6 +155,7 @@ pub async fn insert_chat_message<'a, E: Executor<'a, Database = Postgres>>(
|
|||
message_id: row.message_id,
|
||||
content,
|
||||
created_at: row.created_at,
|
||||
meta_data: Default::default(),
|
||||
};
|
||||
Ok(chat_message)
|
||||
}
|
||||
|
@ -203,20 +204,26 @@ pub async fn select_chat_messages(
|
|||
},
|
||||
}
|
||||
|
||||
let rows: Vec<(i64, String, DateTime<Utc>, serde_json::Value)> =
|
||||
sqlx::query_as_with(&query, args)
|
||||
.fetch_all(txn.deref_mut())
|
||||
.await?;
|
||||
let rows: Vec<(
|
||||
i64,
|
||||
String,
|
||||
DateTime<Utc>,
|
||||
serde_json::Value,
|
||||
serde_json::Value,
|
||||
)> = sqlx::query_as_with(&query, args)
|
||||
.fetch_all(txn.deref_mut())
|
||||
.await?;
|
||||
|
||||
let messages = rows
|
||||
.into_iter()
|
||||
.flat_map(|(message_id, content, created_at, author)| {
|
||||
.flat_map(|(message_id, content, created_at, author, meta_data)| {
|
||||
match serde_json::from_value::<ChatAuthor>(author) {
|
||||
Ok(author) => Some(ChatMessage {
|
||||
author,
|
||||
message_id,
|
||||
content,
|
||||
created_at,
|
||||
meta_data,
|
||||
}),
|
||||
Err(err) => {
|
||||
warn!("Failed to deserialize author: {}", err);
|
||||
|
@ -288,7 +295,7 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
|
|||
let rows = sqlx::query!(
|
||||
// ChatMessage,
|
||||
r#"
|
||||
SELECT message_id, content, created_at, author
|
||||
SELECT message_id, content, created_at, author, meta_data
|
||||
FROM af_chat_messages
|
||||
WHERE chat_id = $1
|
||||
ORDER BY created_at ASC
|
||||
|
@ -307,6 +314,7 @@ pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
|
|||
message_id: row.message_id,
|
||||
content: row.content,
|
||||
created_at: row.created_at,
|
||||
meta_data: row.meta_data,
|
||||
}),
|
||||
Err(err) => {
|
||||
warn!("Failed to deserialize author: {}", err);
|
||||
|
|
|
@ -204,6 +204,7 @@ pub struct AFChatRow {
|
|||
pub deleted_at: Option<DateTime<Utc>>,
|
||||
pub rag_ids: serde_json::Value,
|
||||
pub workspace_id: Uuid,
|
||||
pub meta_data: serde_json::Value,
|
||||
}
|
||||
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
|
||||
pub struct AFChatMessageRow {
|
||||
|
|
6
migrations/20240531031836_chat_message_meta.sql
Normal file
6
migrations/20240531031836_chat_message_meta.sql
Normal file
|
@ -0,0 +1,6 @@
|
|||
-- Add migration script here
|
||||
ALTER TABLE af_chat
|
||||
ADD COLUMN meta_data JSONB DEFAULT '{}' NOT NULL;
|
||||
|
||||
ALTER TABLE af_chat_messages
|
||||
ADD COLUMN meta_data JSONB DEFAULT '{}' NOT NULL;
|
Loading…
Add table
Reference in a new issue