mirror of
https://github.com/AppFlowy-IO/AppFlowy-Cloud.git
synced 2025-04-19 03:24:42 -04:00
chore: enable chat with provided context (#713)
* chore: enable chat with provided context * chore: rename * chore: update create chat message api endpoint * chore: use list context * chore: use list context * chore: fix test * chore: update api endpoint * chore: rename client api function * chore: rename client api function * chore: expose entity * chore: update sqlx files * chore: update test
This commit is contained in:
parent
3b389d7911
commit
a371912c61
23 changed files with 766 additions and 125 deletions
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO af_chat_messages (chat_id, author, content)\n VALUES ($1, $2, $3)\n RETURNING message_id, created_at\n ",
|
||||
"query": "\n INSERT INTO af_chat_messages (chat_id, author, content, meta_data)\n VALUES ($1, $2, $3, $4)\n RETURNING message_id, created_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
|
@ -18,7 +18,8 @@
|
|||
"Left": [
|
||||
"Uuid",
|
||||
"Jsonb",
|
||||
"Text"
|
||||
"Text",
|
||||
"Jsonb"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
|
@ -26,5 +27,5 @@
|
|||
false
|
||||
]
|
||||
},
|
||||
"hash": "6c0a058f5a2a53ad6f89fc7b9af214d7cd5026093bb3f1c94570c87441294977"
|
||||
"hash": "09ff850490eab213cfa0ad88ece9ce7baa39beabee19754fd993268d29552eb9"
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO af_chat_messages (chat_id, author, content)\n VALUES ($1, $2, $3)\n RETURNING message_id, created_at\n ",
|
||||
"query": "\n INSERT INTO af_chat_messages (chat_id, author, content, meta_data)\n VALUES ($1, $2, $3, $4)\n RETURNING message_id, created_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
|
@ -18,7 +18,8 @@
|
|||
"Left": [
|
||||
"Uuid",
|
||||
"Jsonb",
|
||||
"Text"
|
||||
"Text",
|
||||
"Jsonb"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
|
@ -26,5 +27,5 @@
|
|||
false
|
||||
]
|
||||
},
|
||||
"hash": "95b4d7508569cac38c78d21a0a471772d3703e5678ee7ca0cd32d60f5343be91"
|
||||
"hash": "878399ad7c572c8c7c469229af15eadc7421ecc34844d05526196bc3a4147f6b"
|
||||
}
|
|
@ -1,16 +1,17 @@
|
|||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE af_chat_messages\n SET content = $2,\n author = $3,\n created_at = CURRENT_TIMESTAMP\n WHERE message_id = $1\n ",
|
||||
"query": "\n UPDATE af_chat_messages\n SET content = $2,\n author = $3,\n created_at = CURRENT_TIMESTAMP,\n meta_data = $4\n WHERE message_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Text",
|
||||
"Jsonb",
|
||||
"Jsonb"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "8b4af677dd62367e68cdbf9a462e183c702cf98113cacf56a6acb6d130d2a79a"
|
||||
"hash": "da1434fe116cbb48bc5aac0b6905dd748f096bf78d3cdcfea3a576b4aaeba5fc"
|
||||
}
|
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -530,6 +530,7 @@ dependencies = [
|
|||
"appflowy-ai-client",
|
||||
"bytes",
|
||||
"futures",
|
||||
"pin-project",
|
||||
"reqwest 0.12.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
|
@ -15,6 +15,7 @@ tracing = { version = "0.1", optional = true }
|
|||
serde_repr = { version = "0.1", optional = true }
|
||||
futures = "0.3.30"
|
||||
bytes = "1.6.0"
|
||||
pin-project = "1.1.5"
|
||||
|
||||
[dev-dependencies]
|
||||
appflowy-ai-client = { path = ".", features = ["dto", "client-api"] }
|
||||
|
|
|
@ -1,20 +1,27 @@
|
|||
use crate::dto::{
|
||||
AIModel, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document,
|
||||
EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData, RepeatedLocalAIPackage,
|
||||
RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowData,
|
||||
TranslateRowResponse,
|
||||
AIModel, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, CreateTextChatContext,
|
||||
Document, EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData,
|
||||
RepeatedLocalAIPackage, RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse,
|
||||
TranslateRowData, TranslateRowResponse,
|
||||
};
|
||||
use crate::error::AIError;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::{Stream, StreamExt};
|
||||
use futures::{ready, Stream, StreamExt, TryStreamExt};
|
||||
use reqwest;
|
||||
use reqwest::{Method, RequestBuilder, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Map, Value};
|
||||
use serde_json::{json, Map, StreamDeserializer, Value};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use pin_project::pin_project;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde_json::de::SliceRead;
|
||||
use std::borrow::Cow;
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
|
||||
use std::task::{Context, Poll};
|
||||
use tracing::{info, trace};
|
||||
|
||||
const AI_MODEL_HEADER_KEY: &str = "ai-model";
|
||||
|
@ -174,6 +181,20 @@ impl AppFlowyAIClient {
|
|||
.into_data()
|
||||
}
|
||||
|
||||
pub async fn create_chat_text_context(
|
||||
&self,
|
||||
context: CreateTextChatContext,
|
||||
) -> Result<(), AIError> {
|
||||
let url = format!("{}/chat/context/text", self.url);
|
||||
let resp = self
|
||||
.http_client(Method::POST, &url)?
|
||||
.json(&context)
|
||||
.send()
|
||||
.await?;
|
||||
let _ = AIResponse::<()>::from_response(resp).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_question(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
|
@ -220,6 +241,28 @@ impl AppFlowyAIClient {
|
|||
AIResponse::<()>::stream_response(resp).await
|
||||
}
|
||||
|
||||
pub async fn stream_question_v2(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
model: &AIModel,
|
||||
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
|
||||
let json = ChatQuestion {
|
||||
chat_id: chat_id.to_string(),
|
||||
data: MessageData {
|
||||
content: content.to_string(),
|
||||
},
|
||||
};
|
||||
let url = format!("{}/v2/chat/message/stream", self.url);
|
||||
let resp = self
|
||||
.http_client(Method::POST, &url)?
|
||||
.header(AI_MODEL_HEADER_KEY, model.to_str())
|
||||
.json(&json)
|
||||
.send()
|
||||
.await?;
|
||||
AIResponse::<()>::stream_response(resp).await
|
||||
}
|
||||
|
||||
pub async fn get_related_question(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
|
@ -307,6 +350,19 @@ where
|
|||
.map(|item| item.map_err(|err| AIError::Internal(err.into())));
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
pub async fn json_stream_response(
|
||||
resp: reqwest::Response,
|
||||
) -> Result<impl Stream<Item = Result<T, AIError>>, AIError> {
|
||||
let status_code = resp.status();
|
||||
if !status_code.is_success() {
|
||||
let body = resp.text().await?;
|
||||
return Err(AIError::Internal(anyhow!(body)));
|
||||
}
|
||||
|
||||
let stream = resp.bytes_stream().map_err(AIError::from);
|
||||
Ok(JsonStream::new(stream))
|
||||
}
|
||||
}
|
||||
impl From<reqwest::Error> for AIError {
|
||||
fn from(error: reqwest::Error) -> Self {
|
||||
|
@ -324,3 +380,62 @@ impl From<reqwest::Error> for AIError {
|
|||
AIError::Internal(error.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
pub struct JsonStream<T> {
|
||||
stream: Pin<Box<dyn Stream<Item = Result<Bytes, AIError>> + Send>>,
|
||||
buffer: Vec<u8>,
|
||||
_marker: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> JsonStream<T> {
|
||||
pub fn new<S>(stream: S) -> Self
|
||||
where
|
||||
S: Stream<Item = Result<Bytes, AIError>> + Send + 'static,
|
||||
{
|
||||
JsonStream {
|
||||
stream: Box::pin(stream),
|
||||
buffer: Vec::new(),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Stream for JsonStream<T>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
type Item = Result<T, AIError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.project();
|
||||
|
||||
match ready!(this.stream.as_mut().poll_next(cx)) {
|
||||
Some(Ok(bytes)) => {
|
||||
this.buffer.extend_from_slice(&bytes);
|
||||
let de = StreamDeserializer::new(SliceRead::new(this.buffer));
|
||||
let mut iter = de.into_iter();
|
||||
if let Some(result) = iter.next() {
|
||||
match result {
|
||||
Ok(value) => {
|
||||
let remaining = iter.byte_offset();
|
||||
this.buffer.drain(0..remaining);
|
||||
Poll::Ready(Some(Ok(value)))
|
||||
},
|
||||
Err(err) => {
|
||||
if err.is_eof() {
|
||||
Poll::Pending
|
||||
} else {
|
||||
Poll::Ready(Some(Err(AIError::Internal(err.into()))))
|
||||
}
|
||||
},
|
||||
}
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
},
|
||||
Some(Err(err)) => Poll::Ready(Some(Err(err))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ use std::collections::HashMap;
|
|||
use std::fmt::{Display, Formatter};
|
||||
use std::str::FromStr;
|
||||
|
||||
pub const STEAM_METADATA_KEY: &str = "0";
|
||||
pub const STEAM_ANSWER_KEY: &str = "1";
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct SummarizeRowResponse {
|
||||
pub text: String,
|
||||
|
@ -23,6 +25,8 @@ pub struct MessageData {
|
|||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ChatAnswer {
|
||||
pub content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
|
@ -266,3 +270,26 @@ pub struct LocalAIConfig {
|
|||
pub models: Vec<LLMModel>,
|
||||
pub plugin: AppFlowyOfflineAI,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct CreateTextChatContext {
|
||||
pub chat_id: String,
|
||||
/// Only support "txt" and "md" for now
|
||||
pub content_type: String,
|
||||
pub text: String,
|
||||
pub chunk_size: i32,
|
||||
pub chunk_overlap: i32,
|
||||
pub metadata: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
impl Display for CreateTextChatContext {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!(
|
||||
"Create Chat context: {{ chat_id: {}, content_type: {}, content size: {}, metadata: {:?} }}",
|
||||
self.chat_id,
|
||||
self.content_type,
|
||||
self.text.len(),
|
||||
self.metadata
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
23
libs/appflowy-ai-client/tests/chat_test/context_test.rs
Normal file
23
libs/appflowy-ai-client/tests/chat_test/context_test.rs
Normal file
|
@ -0,0 +1,23 @@
|
|||
use crate::appflowy_ai_client;
|
||||
use appflowy_ai_client::dto::{AIModel, CreateTextChatContext};
|
||||
#[tokio::test]
|
||||
async fn create_chat_context_test() {
|
||||
let client = appflowy_ai_client();
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let context = CreateTextChatContext {
|
||||
chat_id: chat_id.clone(),
|
||||
content_type: "txt".to_string(),
|
||||
text: "I have lived in the US for five years".to_string(),
|
||||
chunk_size: 1000,
|
||||
chunk_overlap: 20,
|
||||
metadata: Default::default(),
|
||||
};
|
||||
client.create_chat_text_context(context).await.unwrap();
|
||||
let resp = client
|
||||
.send_question(&chat_id, "Where I live?", &AIModel::GPT35)
|
||||
.await
|
||||
.unwrap();
|
||||
// response will be something like:
|
||||
// Based on the context you provided, you have lived in the US for five years. Therefore, it is likely that you currently live in the US
|
||||
assert!(!resp.content.is_empty());
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
mod completion_test;
|
||||
mod context_test;
|
||||
mod embedding_test;
|
||||
mod qa_test;
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate::appflowy_ai_client;
|
||||
|
||||
use appflowy_ai_client::dto::AIModel;
|
||||
use appflowy_ai_client::client::JsonStream;
|
||||
use appflowy_ai_client::dto::{AIModel, STEAM_ANSWER_KEY};
|
||||
use futures::stream::StreamExt;
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -50,19 +51,23 @@ async fn stream_test() {
|
|||
client.health_check().await.unwrap();
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let stream = client
|
||||
.stream_question(&chat_id, "I feel hungry", &AIModel::GPT35)
|
||||
.stream_question_v2(&chat_id, "I feel hungry", &AIModel::GPT35)
|
||||
.await
|
||||
.unwrap();
|
||||
let json_stream = JsonStream::<serde_json::Value>::new(stream);
|
||||
|
||||
let stream = stream.map(|item| {
|
||||
item.map(|bytes| {
|
||||
String::from_utf8(bytes.to_vec())
|
||||
.map(|s| s.replace('\n', ""))
|
||||
.unwrap()
|
||||
let messages: Vec<String> = json_stream
|
||||
.filter_map(|item| async {
|
||||
match item {
|
||||
Ok(value) => value
|
||||
.get(STEAM_ANSWER_KEY)
|
||||
.and_then(|s| s.as_str().map(ToString::to_string)),
|
||||
Err(_) => None,
|
||||
}
|
||||
})
|
||||
});
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let messages: Vec<String> = stream.map(|message| message.unwrap()).collect().await;
|
||||
println!("final answer: {}", messages.join(""));
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ use std::sync::Arc;
|
|||
use uuid::Uuid;
|
||||
|
||||
use client_api::error::ErrorCode;
|
||||
use console_error_panic_hook;
|
||||
|
||||
use database_entity::dto::QueryCollab;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
|
|
|
@ -5,10 +5,17 @@ use client_api_entity::{
|
|||
ChatMessage, CreateAnswerMessageParams, CreateChatMessageParams, CreateChatParams, MessageCursor,
|
||||
RepeatedChatMessage, UpdateChatMessageContentParams,
|
||||
};
|
||||
use futures_core::Stream;
|
||||
use futures_core::{ready, Stream};
|
||||
use pin_project::pin_project;
|
||||
use reqwest::Method;
|
||||
use shared_entity::dto::ai_dto::RepeatedRelatedQuestion;
|
||||
use serde_json::Value;
|
||||
use shared_entity::dto::ai_dto::{
|
||||
CreateTextChatContext, RepeatedRelatedQuestion, STEAM_ANSWER_KEY, STEAM_METADATA_KEY,
|
||||
};
|
||||
use shared_entity::response::{AppResponse, AppResponseError};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tracing::error;
|
||||
|
||||
impl Client {
|
||||
/// Create a new chat
|
||||
|
@ -45,7 +52,7 @@ impl Client {
|
|||
}
|
||||
|
||||
/// Save a question message to a chat
|
||||
pub async fn save_question(
|
||||
pub async fn create_question(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
|
@ -91,14 +98,14 @@ impl Client {
|
|||
}
|
||||
|
||||
/// Ask AI with a question for given question's message_id
|
||||
pub async fn ask_question(
|
||||
pub async fn stream_answer(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
message_id: i64,
|
||||
question_message_id: i64,
|
||||
) -> Result<impl Stream<Item = Result<Bytes, AppResponseError>>, AppResponseError> {
|
||||
let url = format!(
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/{message_id}/answer/stream",
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/{question_message_id}/answer/stream",
|
||||
self.base_url
|
||||
);
|
||||
let resp = self
|
||||
|
@ -110,16 +117,36 @@ impl Client {
|
|||
AppResponse::<()>::answer_response_stream(resp).await
|
||||
}
|
||||
|
||||
/// Generate an answer for given question's message_id. The same as ask_question but return ChatMessage
|
||||
/// instead of stream of Bytes
|
||||
pub async fn generate_answer(
|
||||
pub async fn stream_answer_v2(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
message_id: i64,
|
||||
question_message_id: i64,
|
||||
) -> Result<QuestionStream, AppResponseError> {
|
||||
let url = format!(
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/{question_message_id}/v2/answer/stream",
|
||||
self.base_url
|
||||
);
|
||||
let resp = self
|
||||
.http_client_with_auth(Method::GET, &url)
|
||||
.await?
|
||||
.send()
|
||||
.await?;
|
||||
log_request_id(&resp);
|
||||
let stream = AppResponse::<serde_json::Value>::json_response_stream(resp).await?;
|
||||
Ok(QuestionStream::new(stream))
|
||||
}
|
||||
|
||||
/// Generate an answer for given question's message_id. The same as ask_question but return ChatMessage
|
||||
/// instead of stream of Bytes
|
||||
pub async fn get_answer(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
question_message_id: i64,
|
||||
) -> Result<ChatMessage, AppResponseError> {
|
||||
let url = format!(
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/{message_id}/answer",
|
||||
"{}/api/chat/{workspace_id}/{chat_id}/{question_message_id}/answer",
|
||||
self.base_url
|
||||
);
|
||||
let resp = self
|
||||
|
@ -231,4 +258,79 @@ impl Client {
|
|||
log_request_id(&resp);
|
||||
AppResponse::<ChatMessage>::json_response_stream(resp).await
|
||||
}
|
||||
|
||||
pub async fn create_chat_context(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
params: CreateTextChatContext,
|
||||
) -> Result<(), AppResponseError> {
|
||||
let url = format!(
|
||||
"{}/api/chat/{workspace_id}/{}/context/text",
|
||||
self.base_url, params.chat_id
|
||||
);
|
||||
let resp = self
|
||||
.http_client_with_auth(Method::POST, &url)
|
||||
.await?
|
||||
.json(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
log_request_id(&resp);
|
||||
AppResponse::<()>::from_response(resp).await?.into_error()
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
pub struct QuestionStream {
|
||||
stream: Pin<Box<dyn Stream<Item = Result<serde_json::Value, AppResponseError>> + Send>>,
|
||||
buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl QuestionStream {
|
||||
pub fn new<S>(stream: S) -> Self
|
||||
where
|
||||
S: Stream<Item = Result<serde_json::Value, AppResponseError>> + Send + 'static,
|
||||
{
|
||||
QuestionStream {
|
||||
stream: Box::pin(stream),
|
||||
buffer: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum QuestionStreamValue {
|
||||
Answer { value: String },
|
||||
Metadata { value: serde_json::Value },
|
||||
}
|
||||
impl Stream for QuestionStream {
|
||||
type Item = Result<QuestionStreamValue, AppResponseError>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.project();
|
||||
|
||||
return match ready!(this.stream.as_mut().poll_next(cx)) {
|
||||
Some(Ok(value)) => match value {
|
||||
Value::Object(mut value) => {
|
||||
if let Some(metadata) = value.remove(STEAM_METADATA_KEY) {
|
||||
return Poll::Ready(Some(Ok(QuestionStreamValue::Metadata { value: metadata })));
|
||||
}
|
||||
|
||||
if let Some(answer) = value
|
||||
.remove(STEAM_ANSWER_KEY)
|
||||
.and_then(|s| s.as_str().map(ToString::to_string))
|
||||
{
|
||||
return Poll::Ready(Some(Ok(QuestionStreamValue::Answer { value: answer })));
|
||||
}
|
||||
|
||||
error!("Invalid streaming value: {:?}", value);
|
||||
Poll::Ready(None)
|
||||
},
|
||||
_ => {
|
||||
error!("Unexpected JSON value type: {:?}", value);
|
||||
Poll::Ready(None)
|
||||
},
|
||||
},
|
||||
Some(Err(err)) => Poll::Ready(Some(Err(err))),
|
||||
None => Poll::Ready(None),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ pub use wasm::*;
|
|||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod http_chat;
|
||||
|
||||
mod http_search;
|
||||
mod http_settings;
|
||||
pub mod ws;
|
||||
|
@ -37,6 +38,8 @@ pub mod error {
|
|||
|
||||
// Export all dto entities that will be used in the frontend application
|
||||
pub mod entity {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use crate::http_chat::{QuestionStream, QuestionStreamValue};
|
||||
pub use client_api_entity::*;
|
||||
}
|
||||
|
||||
|
|
|
@ -648,6 +648,35 @@ pub struct CreateChatMessageParams {
|
|||
#[validate(custom = "validate_not_empty_str")]
|
||||
pub content: String,
|
||||
pub message_type: ChatMessageType,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessageMetadata {
|
||||
pub data: ChatMetadataData,
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMetadataData {
|
||||
/// Don't rename this field, it's used [ops::extract_chat_message_metadata]
|
||||
content: String,
|
||||
pub content_type: String,
|
||||
size: i64,
|
||||
}
|
||||
|
||||
impl ChatMetadataData {
|
||||
pub fn new_text(content: String) -> Self {
|
||||
let size = content.len();
|
||||
Self {
|
||||
content,
|
||||
content_type: "text".to_string(),
|
||||
size: size as i64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
@ -678,6 +707,7 @@ impl CreateChatMessageParams {
|
|||
Self {
|
||||
content: content.to_string(),
|
||||
message_type: ChatMessageType::System,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -685,8 +715,14 @@ impl CreateChatMessageParams {
|
|||
Self {
|
||||
content: content.to_string(),
|
||||
message_type: ChatMessageType::User,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
|
||||
self.metadata = Some(metadata);
|
||||
self
|
||||
}
|
||||
}
|
||||
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
|
||||
pub struct GetChatMessageParams {
|
||||
|
@ -828,6 +864,9 @@ pub struct CreateAnswerMessageParams {
|
|||
#[validate(custom = "validate_not_empty_str")]
|
||||
pub content: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
|
||||
pub question_message_id: i64,
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ use database_entity::dto::{
|
|||
|
||||
use serde_json::json;
|
||||
use sqlx::postgres::PgArguments;
|
||||
use sqlx::types::JsonValue;
|
||||
use sqlx::{Arguments, Executor, PgPool, Postgres, Transaction};
|
||||
use std::ops::DerefMut;
|
||||
use std::str::FromStr;
|
||||
|
@ -135,6 +136,7 @@ pub async fn insert_answer_message_with_transaction(
|
|||
author: ChatAuthor,
|
||||
chat_id: &str,
|
||||
content: String,
|
||||
metadata: serde_json::Value,
|
||||
question_message_id: i64,
|
||||
) -> Result<ChatMessage, AppError> {
|
||||
let chat_id = Uuid::from_str(chat_id)?;
|
||||
|
@ -156,12 +158,14 @@ pub async fn insert_answer_message_with_transaction(
|
|||
UPDATE af_chat_messages
|
||||
SET content = $2,
|
||||
author = $3,
|
||||
created_at = CURRENT_TIMESTAMP
|
||||
created_at = CURRENT_TIMESTAMP,
|
||||
meta_data = $4
|
||||
WHERE message_id = $1
|
||||
"#,
|
||||
reply_id,
|
||||
&content,
|
||||
json!(author),
|
||||
metadata,
|
||||
)
|
||||
.execute(transaction.deref_mut())
|
||||
.await
|
||||
|
@ -193,13 +197,14 @@ pub async fn insert_answer_message_with_transaction(
|
|||
// Insert a new chat message
|
||||
let row = sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO af_chat_messages (chat_id, author, content)
|
||||
VALUES ($1, $2, $3)
|
||||
INSERT INTO af_chat_messages (chat_id, author, content, meta_data)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING message_id, created_at
|
||||
"#,
|
||||
chat_id,
|
||||
json!(author),
|
||||
&content,
|
||||
&metadata,
|
||||
)
|
||||
.fetch_one(transaction.deref_mut())
|
||||
.await
|
||||
|
@ -224,7 +229,7 @@ pub async fn insert_answer_message_with_transaction(
|
|||
message_id: row.message_id,
|
||||
content,
|
||||
created_at: row.created_at,
|
||||
meta_data: Default::default(),
|
||||
meta_data: metadata,
|
||||
reply_message_id: None,
|
||||
};
|
||||
|
||||
|
@ -237,12 +242,19 @@ pub async fn insert_answer_message(
|
|||
author: ChatAuthor,
|
||||
chat_id: &str,
|
||||
content: String,
|
||||
metadata: Option<serde_json::Value>,
|
||||
question_message_id: i64,
|
||||
) -> Result<ChatMessage, AppError> {
|
||||
let mut txn = pg_pool.begin().await?;
|
||||
let chat_message =
|
||||
insert_answer_message_with_transaction(&mut txn, author, chat_id, content, question_message_id)
|
||||
.await?;
|
||||
let chat_message = insert_answer_message_with_transaction(
|
||||
&mut txn,
|
||||
author,
|
||||
chat_id,
|
||||
content,
|
||||
metadata.unwrap_or_default(),
|
||||
question_message_id,
|
||||
)
|
||||
.await?;
|
||||
txn.commit().await.map_err(|err| {
|
||||
AppError::Internal(anyhow!(
|
||||
"Failed to commit transaction to insert answer message: {}",
|
||||
|
@ -257,17 +269,20 @@ pub async fn insert_question_message<'a, E: Executor<'a, Database = Postgres>>(
|
|||
author: ChatAuthor,
|
||||
chat_id: &str,
|
||||
content: String,
|
||||
metadata: Option<JsonValue>,
|
||||
) -> Result<ChatMessage, AppError> {
|
||||
let metadata = metadata.unwrap_or_else(|| json!({}));
|
||||
let chat_id = Uuid::from_str(chat_id)?;
|
||||
let row = sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO af_chat_messages (chat_id, author, content)
|
||||
VALUES ($1, $2, $3)
|
||||
INSERT INTO af_chat_messages (chat_id, author, content, meta_data)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING message_id, created_at
|
||||
"#,
|
||||
chat_id,
|
||||
json!(author),
|
||||
&content,
|
||||
&metadata,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
|
@ -278,7 +293,7 @@ pub async fn insert_question_message<'a, E: Executor<'a, Database = Postgres>>(
|
|||
message_id: row.message_id,
|
||||
content,
|
||||
created_at: row.created_at,
|
||||
meta_data: Default::default(),
|
||||
meta_data: metadata,
|
||||
reply_message_id: None,
|
||||
};
|
||||
Ok(chat_message)
|
||||
|
|
|
@ -89,32 +89,32 @@ where
|
|||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.project();
|
||||
|
||||
loop {
|
||||
match ready!(this.stream.as_mut().poll_next(cx)) {
|
||||
Some(Ok(bytes)) => {
|
||||
this.buffer.extend_from_slice(&bytes);
|
||||
let de = StreamDeserializer::new(SliceRead::new(this.buffer));
|
||||
let mut iter = de.into_iter();
|
||||
if let Some(result) = iter.next() {
|
||||
match result {
|
||||
Ok(value) => {
|
||||
let remaining = iter.byte_offset();
|
||||
this.buffer.drain(0..remaining);
|
||||
return Poll::Ready(Some(Ok(value)));
|
||||
},
|
||||
Err(err) => {
|
||||
if err.is_eof() {
|
||||
continue;
|
||||
} else {
|
||||
return Poll::Ready(Some(Err(AppResponseError::from(err))));
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
Some(Err(err)) => return Poll::Ready(Some(Err(err))),
|
||||
None => return Poll::Ready(None),
|
||||
}
|
||||
match ready!(this.stream.as_mut().poll_next(cx)) {
|
||||
Some(Ok(bytes)) => {
|
||||
this.buffer.extend_from_slice(&bytes);
|
||||
let de = StreamDeserializer::new(SliceRead::new(this.buffer));
|
||||
let mut iter = de.into_iter();
|
||||
if let Some(result) = iter.next() {
|
||||
return match result {
|
||||
Ok(value) => {
|
||||
let remaining = iter.byte_offset();
|
||||
this.buffer.drain(0..remaining);
|
||||
Poll::Ready(Some(Ok(value)))
|
||||
},
|
||||
Err(err) => {
|
||||
if err.is_eof() {
|
||||
Poll::Pending
|
||||
} else {
|
||||
Poll::Ready(Some(Err(AppResponseError::from(err))))
|
||||
}
|
||||
},
|
||||
};
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
},
|
||||
Some(Err(err)) => Poll::Ready(Some(Err(err))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
111
src/api/chat.rs
111
src/api/chat.rs
|
@ -1,13 +1,14 @@
|
|||
use crate::biz::chat::ops::{
|
||||
create_chat, create_chat_message, create_chat_question, delete_chat,
|
||||
generate_chat_message_answer, get_chat_messages, update_chat_message,
|
||||
create_chat, create_chat_message, create_chat_message_stream, delete_chat,
|
||||
extract_chat_message_metadata, generate_chat_message_answer, get_chat_messages,
|
||||
update_chat_message, ExtractChatMetadata,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use actix_web::web::{Data, Json};
|
||||
use actix_web::{web, HttpRequest, HttpResponse, Scope};
|
||||
|
||||
use app_error::AppError;
|
||||
use appflowy_ai_client::dto::RepeatedRelatedQuestion;
|
||||
use appflowy_ai_client::dto::{CreateTextChatContext, RepeatedRelatedQuestion};
|
||||
use authentication::jwt::UserUuid;
|
||||
use bytes::Bytes;
|
||||
use database_entity::dto::{
|
||||
|
@ -52,20 +53,32 @@ pub fn chat_scope() -> Scope {
|
|||
.route(web::put().to(update_chat_message_handler)),
|
||||
)
|
||||
.service(
|
||||
// Create a question for given chat
|
||||
// Creating a [ChatMessage] for given content.
|
||||
// When client asks a question, it will use this API to create a chat message
|
||||
web::resource("/{chat_id}/message/question").route(web::post().to(create_question_handler)),
|
||||
)
|
||||
// create an answer for given chat
|
||||
.service(web::resource("/{chat_id}/message/answer").route(web::post().to(create_answer_handler)))
|
||||
// Writing the final answer for a given chat.
|
||||
// After the streaming is finished, the client will use this API to save the message to disk.
|
||||
.service(web::resource("/{chat_id}/message/answer").route(web::post().to(save_answer_handler)))
|
||||
.service(
|
||||
// Generate answer for given question.
|
||||
web::resource("/{chat_id}/{message_id}/answer").route(web::get().to(gen_answer_handler)),
|
||||
// Use AI to generate a response for a specified message ID.
|
||||
// To generate an answer for a given question, use "/answer/stream" to receive the answer in a stream.
|
||||
web::resource("/{chat_id}/{message_id}/answer").route(web::get().to(answer_handler)),
|
||||
)
|
||||
// Stream the answer for given question.
|
||||
// Use AI to generate a response for a specified message ID. This response will be return as a stream.
|
||||
.service(
|
||||
web::resource("/{chat_id}/{message_id}/answer/stream")
|
||||
.route(web::get().to(answer_stream_handler)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/{chat_id}/{message_id}/v2/answer/stream")
|
||||
.route(web::get().to(answer_stream_v2_handler)),
|
||||
)
|
||||
.service(
|
||||
// Create chat context for a given chat.
|
||||
web::resource("/{chat_id}/context/text")
|
||||
.route(web::post().to(create_chat_context_handler))
|
||||
)
|
||||
}
|
||||
async fn create_chat_handler(
|
||||
path: web::Path<String>,
|
||||
|
@ -105,7 +118,7 @@ async fn create_chat_message_handler(
|
|||
|
||||
let ai_model = ai_model_from_header(&req);
|
||||
let uid = state.user_cache.get_user_uid(&uuid).await?;
|
||||
let message_stream = create_chat_message(
|
||||
let message_stream = create_chat_message_stream(
|
||||
&state.pg_pool,
|
||||
uid,
|
||||
chat_id,
|
||||
|
@ -122,6 +135,20 @@ async fn create_chat_message_handler(
|
|||
)
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn create_chat_context_handler(
|
||||
state: Data<AppState>,
|
||||
payload: Json<CreateTextChatContext>,
|
||||
) -> actix_web::Result<JsonAppResponse<()>> {
|
||||
let params = payload.into_inner();
|
||||
state
|
||||
.ai_client
|
||||
.create_chat_text_context(params)
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
Ok(AppResponse::Ok().into())
|
||||
}
|
||||
|
||||
async fn update_chat_message_handler(
|
||||
state: Data<AppState>,
|
||||
payload: Json<UpdateChatMessageContentParams>,
|
||||
|
@ -155,14 +182,35 @@ async fn create_question_handler(
|
|||
uuid: UserUuid,
|
||||
) -> actix_web::Result<JsonAppResponse<ChatMessage>> {
|
||||
let (_workspace_id, chat_id) = path.into_inner();
|
||||
let params = payload.into_inner();
|
||||
let mut params = payload.into_inner();
|
||||
|
||||
for extract_context in extract_chat_message_metadata(&mut params) {
|
||||
match extract_context {
|
||||
ExtractChatMetadata::Text { text, metadata } => {
|
||||
let context = CreateTextChatContext {
|
||||
chat_id: chat_id.clone(),
|
||||
content_type: "txt".to_string(),
|
||||
text,
|
||||
chunk_size: 2000,
|
||||
chunk_overlap: 20,
|
||||
metadata,
|
||||
};
|
||||
trace!("create chat context: {}", context);
|
||||
state
|
||||
.ai_client
|
||||
.create_chat_text_context(context)
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
let uid = state.user_cache.get_user_uid(&uuid).await?;
|
||||
let resp = create_chat_question(&state.pg_pool, uid, chat_id, params).await?;
|
||||
let resp = create_chat_message(&state.pg_pool, uid, chat_id, params).await?;
|
||||
Ok(AppResponse::Ok().with_data(resp).into())
|
||||
}
|
||||
|
||||
async fn create_answer_handler(
|
||||
async fn save_answer_handler(
|
||||
path: web::Path<(String, String)>,
|
||||
payload: Json<CreateAnswerMessageParams>,
|
||||
state: Data<AppState>,
|
||||
|
@ -176,13 +224,14 @@ async fn create_answer_handler(
|
|||
ChatAuthor::ai(),
|
||||
&chat_id,
|
||||
payload.content,
|
||||
payload.metadata,
|
||||
payload.question_message_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(AppResponse::Ok().with_data(message).into())
|
||||
}
|
||||
async fn gen_answer_handler(
|
||||
async fn answer_handler(
|
||||
path: web::Path<(String, String, i64)>,
|
||||
state: Data<AppState>,
|
||||
req: HttpRequest,
|
||||
|
@ -200,7 +249,7 @@ async fn gen_answer_handler(
|
|||
Ok(AppResponse::Ok().with_data(message).into())
|
||||
}
|
||||
|
||||
#[instrument(level = "info", skip_all, err)]
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn answer_stream_handler(
|
||||
path: web::Path<(String, String, i64)>,
|
||||
state: Data<AppState>,
|
||||
|
@ -232,6 +281,38 @@ async fn answer_stream_handler(
|
|||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn answer_stream_v2_handler(
|
||||
path: web::Path<(String, String, i64)>,
|
||||
state: Data<AppState>,
|
||||
req: HttpRequest,
|
||||
) -> actix_web::Result<HttpResponse> {
|
||||
let (_workspace_id, chat_id, question_id) = path.into_inner();
|
||||
let content = chat::chat_ops::select_chat_message_content(&state.pg_pool, question_id).await?;
|
||||
let ai_model = ai_model_from_header(&req);
|
||||
match state
|
||||
.ai_client
|
||||
.stream_question_v2(&chat_id, &content, &ai_model)
|
||||
.await
|
||||
{
|
||||
Ok(answer_stream) => {
|
||||
let new_answer_stream = answer_stream.map_err(AppError::from);
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("text/event-stream")
|
||||
.streaming(new_answer_stream),
|
||||
)
|
||||
},
|
||||
Err(err) => Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("text/event-stream")
|
||||
.streaming(stream::once(async move {
|
||||
Err(AppError::AIServiceUnavailable(err.to_string()))
|
||||
})),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(level = "debug", skip_all, err)]
|
||||
async fn get_chat_message_handler(
|
||||
path: web::Path<(String, String)>,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use actix_web::web::Bytes;
|
||||
use anyhow::anyhow;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use app_error::AppError;
|
||||
use appflowy_ai_client::client::AppFlowyAIClient;
|
||||
|
@ -15,8 +16,9 @@ use database_entity::dto::{
|
|||
CreateChatParams, GetChatMessageParams, RepeatedChatMessage, UpdateChatMessageContentParams,
|
||||
};
|
||||
use futures::stream::Stream;
|
||||
use serde_json::Value;
|
||||
use sqlx::PgPool;
|
||||
use tracing::error;
|
||||
use tracing::{error, info, trace};
|
||||
|
||||
use appflowy_ai_client::dto::AIModel;
|
||||
use validator::Validate;
|
||||
|
@ -65,6 +67,7 @@ pub async fn update_chat_message(
|
|||
ChatAuthor::ai(),
|
||||
¶ms.chat_id,
|
||||
new_answer.content,
|
||||
new_answer.metadata,
|
||||
params.message_id,
|
||||
)
|
||||
.await?;
|
||||
|
@ -84,6 +87,7 @@ pub async fn generate_chat_message_answer(
|
|||
.send_question(chat_id, &content, &ai_model)
|
||||
.await?;
|
||||
|
||||
info!("new_answer: {:?}", new_answer);
|
||||
// Save the answer to the database
|
||||
let mut txn = pg_pool.begin().await?;
|
||||
let message = insert_answer_message_with_transaction(
|
||||
|
@ -91,6 +95,7 @@ pub async fn generate_chat_message_answer(
|
|||
ChatAuthor::ai(),
|
||||
chat_id,
|
||||
new_answer.content,
|
||||
new_answer.metadata.unwrap_or_default(),
|
||||
question_message_id,
|
||||
)
|
||||
.await?;
|
||||
|
@ -109,6 +114,136 @@ pub async fn create_chat_message(
|
|||
uid: i64,
|
||||
chat_id: String,
|
||||
params: CreateChatMessageParams,
|
||||
) -> Result<ChatMessage, AppError> {
|
||||
let params = params.clone();
|
||||
let chat_id = chat_id.clone();
|
||||
let pg_pool = pg_pool.clone();
|
||||
|
||||
let question = insert_question_message(
|
||||
&pg_pool,
|
||||
ChatAuthor::new(uid, ChatAuthorType::Human),
|
||||
&chat_id,
|
||||
params.content.clone(),
|
||||
params.metadata,
|
||||
)
|
||||
.await?;
|
||||
Ok(question)
|
||||
}
|
||||
|
||||
enum ContextType {
|
||||
Unknown,
|
||||
Text,
|
||||
}
|
||||
|
||||
/// Extracts the chat context from the metadata. Currently, we only support text as a context. In
|
||||
/// the future, we will support other types of context.
|
||||
pub(crate) enum ExtractChatMetadata {
|
||||
Text {
|
||||
text: String,
|
||||
metadata: HashMap<String, Value>,
|
||||
},
|
||||
}
|
||||
/// Removes the "content" field from the metadata if the "ty" field is equal to "text".
|
||||
/// The metadata struct is shown below:
|
||||
/// {
|
||||
/// "data": {
|
||||
/// "content": "hello world"
|
||||
/// "size": 122,
|
||||
/// "content_type": "text",
|
||||
/// },
|
||||
/// "id": "id",
|
||||
/// "name": "name"
|
||||
/// }
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `params`: A mutable reference to `CreateChatMessageParams` which contains metadata.
|
||||
///
|
||||
/// # Returns
|
||||
/// - `Option<(String, HashMap<String, serde_json::Value>)>`: A tuple containing the removed content and the updated metadata, otherwise `None`.
|
||||
fn extract_message_metadata(
|
||||
message_metadata: &mut serde_json::Value,
|
||||
) -> Option<ExtractChatMetadata> {
|
||||
trace!("Extracting metadata: {:?}", message_metadata);
|
||||
|
||||
if let Value::Object(message_metadata) = message_metadata {
|
||||
let mut context_type = ContextType::Unknown;
|
||||
if let Some(Value::Object(data)) = message_metadata.get("data") {
|
||||
if let Some(ty) = data.get("content_type").and_then(|v| v.as_str()) {
|
||||
match ty {
|
||||
"text" => context_type = ContextType::Text,
|
||||
_ => context_type = ContextType::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match context_type {
|
||||
ContextType::Unknown => {
|
||||
// do nothing
|
||||
},
|
||||
ContextType::Text => {
|
||||
// remove the "data" field from the context if the "ty" field is equal to "text"
|
||||
let mut text = None;
|
||||
if let Some(Value::Object(ref mut data)) = message_metadata.remove("data") {
|
||||
let content = data
|
||||
.remove("content")
|
||||
.and_then(|value| {
|
||||
if let Value::String(s) = value {
|
||||
Some(s)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let content_size = data
|
||||
.remove("size")
|
||||
.and_then(|value| {
|
||||
if let Value::Number(n) = value {
|
||||
n.as_i64()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or(0);
|
||||
|
||||
// If the content is not empty and the content size is equal to the length of the content
|
||||
if !content.is_empty() && content.len() == content_size as usize {
|
||||
text = Some(content);
|
||||
}
|
||||
}
|
||||
|
||||
return text.map(|text| ExtractChatMetadata::Text {
|
||||
text,
|
||||
metadata: message_metadata.clone().into_iter().collect(),
|
||||
});
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn extract_chat_message_metadata(
|
||||
params: &mut CreateChatMessageParams,
|
||||
) -> Vec<ExtractChatMetadata> {
|
||||
let mut extract_metadatas = vec![];
|
||||
if let Some(Value::Array(ref mut list)) = params.metadata {
|
||||
trace!("Extracting chat metadata: {:?}", list);
|
||||
for metadata in list {
|
||||
if let Some(extract_context) = extract_message_metadata(metadata) {
|
||||
extract_metadatas.push(extract_context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extract_metadatas
|
||||
}
|
||||
|
||||
pub async fn create_chat_message_stream(
|
||||
pg_pool: &PgPool,
|
||||
uid: i64,
|
||||
chat_id: String,
|
||||
params: CreateChatMessageParams,
|
||||
ai_client: AppFlowyAIClient,
|
||||
ai_model: AIModel,
|
||||
) -> impl Stream<Item = Result<Bytes, AppError>> {
|
||||
|
@ -121,7 +256,8 @@ pub async fn create_chat_message(
|
|||
&pg_pool,
|
||||
ChatAuthor::new(uid, ChatAuthorType::Human),
|
||||
&chat_id,
|
||||
params.content.clone()
|
||||
params.content.clone(),
|
||||
params.metadata,
|
||||
).await {
|
||||
Ok(question) => question,
|
||||
Err(err) => {
|
||||
|
@ -147,8 +283,8 @@ pub async fn create_chat_message(
|
|||
match params.message_type {
|
||||
ChatMessageType::System => {}
|
||||
ChatMessageType::User => {
|
||||
let content = match ai_client.send_question(&chat_id, ¶ms.content, &ai_model).await {
|
||||
Ok(response) => response.content,
|
||||
let answer = match ai_client.send_question(&chat_id, ¶ms.content, &ai_model).await {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
error!("Failed to send question to AI: {}", err);
|
||||
yield Err(AppError::from(err));
|
||||
|
@ -156,7 +292,7 @@ pub async fn create_chat_message(
|
|||
}
|
||||
};
|
||||
|
||||
let answer = match insert_answer_message(&pg_pool, ChatAuthor::ai(), &chat_id, content.clone(),question_id).await {
|
||||
let answer = match insert_answer_message(&pg_pool, ChatAuthor::ai(), &chat_id, answer.content, answer.metadata,question_id).await {
|
||||
Ok(answer) => answer,
|
||||
Err(err) => {
|
||||
error!("Failed to insert answer message: {}", err);
|
||||
|
@ -194,22 +330,3 @@ pub async fn get_chat_messages(
|
|||
txn.commit().await?;
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
pub async fn create_chat_question(
|
||||
pg_pool: &PgPool,
|
||||
uid: i64,
|
||||
chat_id: String,
|
||||
params: CreateChatMessageParams,
|
||||
) -> Result<ChatMessage, AppError> {
|
||||
let params = params.clone();
|
||||
let chat_id = chat_id.clone();
|
||||
let pg_pool = pg_pool.clone();
|
||||
let question = insert_question_message(
|
||||
&pg_pool,
|
||||
ChatAuthor::new(uid, ChatAuthorType::Human),
|
||||
&chat_id,
|
||||
params.content.clone(),
|
||||
)
|
||||
.await?;
|
||||
Ok(question)
|
||||
}
|
||||
|
|
1
tests/ai_test/asset/my_profile.txt
Normal file
1
tests/ai_test/asset/my_profile.txt
Normal file
|
@ -0,0 +1 @@
|
|||
I am Lucas. I live in Singapore for 10 years and work as a software engineer. I'm a fan of AI and enjoy exploring its potential.
|
|
@ -1,6 +1,14 @@
|
|||
use crate::ai_test::util::read_text_from_asset;
|
||||
use appflowy_ai_client::dto::CreateTextChatContext;
|
||||
use assert_json_diff::assert_json_eq;
|
||||
use client_api::entity::QuestionStreamValue;
|
||||
use client_api_test::TestClient;
|
||||
use database_entity::dto::{ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor};
|
||||
use database_entity::dto::{
|
||||
ChatMessage, ChatMessageMetadata, ChatMetadataData, CreateChatMessageParams, CreateChatParams,
|
||||
MessageCursor,
|
||||
};
|
||||
use futures_util::StreamExt;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_chat_and_create_messages_test() {
|
||||
|
@ -98,7 +106,7 @@ async fn chat_qa_test() {
|
|||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let params = CreateChatParams {
|
||||
chat_id: chat_id.clone(),
|
||||
name: "my second chat".to_string(),
|
||||
name: "new chat".to_string(),
|
||||
rag_ids: vec![],
|
||||
};
|
||||
|
||||
|
@ -108,19 +116,42 @@ async fn chat_qa_test() {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let params = CreateChatMessageParams::new_user("where is singapore?");
|
||||
let stream = test_client
|
||||
let content = read_text_from_asset("my_profile.txt");
|
||||
let metadata = ChatMessageMetadata {
|
||||
data: ChatMetadataData::new_text(content),
|
||||
id: "123".to_string(),
|
||||
name: "test context".to_string(),
|
||||
source: "user added".to_string(),
|
||||
};
|
||||
|
||||
let params =
|
||||
CreateChatMessageParams::new_user("Where lucas live?").with_metadata(json!(vec![metadata]));
|
||||
let question = test_client
|
||||
.api_client
|
||||
.create_question_answer(&workspace_id, &chat_id, params)
|
||||
.create_question(&workspace_id, &chat_id, params)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
|
||||
assert_eq!(messages.len(), 2);
|
||||
let answer = test_client
|
||||
.api_client
|
||||
.get_answer(&workspace_id, &chat_id, question.message_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(answer.content.contains("Singapore"));
|
||||
assert_json_eq!(
|
||||
answer.meta_data,
|
||||
json!([
|
||||
{
|
||||
"id": "123",
|
||||
"name": "test context",
|
||||
"source": "user added",
|
||||
}
|
||||
])
|
||||
);
|
||||
|
||||
let related_questions = test_client
|
||||
.api_client
|
||||
.get_chat_related_question(&workspace_id, &chat_id, messages[1].message_id)
|
||||
.get_chat_related_question(&workspace_id, &chat_id, question.message_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(related_questions.items.len(), 3);
|
||||
|
@ -154,7 +185,7 @@ async fn generate_chat_message_answer_test() {
|
|||
|
||||
let answer = test_client
|
||||
.api_client
|
||||
.generate_answer(&workspace_id, &chat_id, messages[0].message_id)
|
||||
.get_answer(&workspace_id, &chat_id, messages[0].message_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -188,23 +219,88 @@ async fn generate_stream_answer_test() {
|
|||
let params = CreateChatMessageParams::new_user("Teach me how to write a article?");
|
||||
let question = test_client
|
||||
.api_client
|
||||
.save_question(&workspace_id, &chat_id, params)
|
||||
.create_question(&workspace_id, &chat_id, params)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// test v1 api endpoint
|
||||
let mut answer_stream = test_client
|
||||
.api_client
|
||||
.ask_question(&workspace_id, &chat_id, question.message_id)
|
||||
.stream_answer(&workspace_id, &chat_id, question.message_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut answer = String::new();
|
||||
let mut answer_v1 = String::new();
|
||||
while let Some(message) = answer_stream.next().await {
|
||||
let message = message.unwrap();
|
||||
let s = String::from_utf8(message.to_vec()).unwrap();
|
||||
answer.push_str(&s);
|
||||
answer_v1.push_str(&s);
|
||||
}
|
||||
assert!(!answer.is_empty());
|
||||
assert!(!answer_v1.is_empty());
|
||||
|
||||
// test v2 api endpoint
|
||||
let mut answer_stream = test_client
|
||||
.api_client
|
||||
.stream_answer_v2(&workspace_id, &chat_id, question.message_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut answer_v2 = String::new();
|
||||
while let Some(value) = answer_stream.next().await {
|
||||
match value.unwrap() {
|
||||
QuestionStreamValue::Answer { value } => {
|
||||
answer_v2.push_str(&value);
|
||||
},
|
||||
QuestionStreamValue::Metadata { .. } => {},
|
||||
}
|
||||
}
|
||||
assert!(!answer_v2.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_chat_context_test() {
|
||||
let test_client = TestClient::new_user_without_ws_conn().await;
|
||||
let workspace_id = test_client.workspace_id().await;
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let params = CreateChatParams {
|
||||
chat_id: chat_id.clone(),
|
||||
name: "context chat".to_string(),
|
||||
rag_ids: vec![],
|
||||
};
|
||||
|
||||
test_client
|
||||
.api_client
|
||||
.create_chat(&workspace_id, params)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context = CreateTextChatContext {
|
||||
chat_id: chat_id.clone(),
|
||||
content_type: "txt".to_string(),
|
||||
text: "I have lived in the US for five years".to_string(),
|
||||
chunk_size: 1000,
|
||||
chunk_overlap: 20,
|
||||
metadata: Default::default(),
|
||||
};
|
||||
|
||||
test_client
|
||||
.api_client
|
||||
.create_chat_context(&workspace_id, context)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let params = CreateChatMessageParams::new_user("Where I live?");
|
||||
let question = test_client
|
||||
.api_client
|
||||
.create_question(&workspace_id, &chat_id, params)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let answer = test_client
|
||||
.api_client
|
||||
.get_answer(&workspace_id, &chat_id, question.message_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(answer.content.contains("US"));
|
||||
println!("answer: {:?}", answer);
|
||||
}
|
||||
|
||||
// #[tokio::test]
|
||||
|
|
|
@ -2,3 +2,4 @@ mod chat_test;
|
|||
mod complete_text;
|
||||
// mod local_ai_test;
|
||||
mod summarize_row;
|
||||
mod util;
|
||||
|
|
9
tests/ai_test/util.rs
Normal file
9
tests/ai_test/util.rs
Normal file
|
@ -0,0 +1,9 @@
|
|||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
|
||||
pub(crate) fn read_text_from_asset(file_name: &str) -> String {
|
||||
let mut file = File::open(format!("./tests/ai_test/asset/{}", file_name)).unwrap();
|
||||
let mut buffer = Vec::new();
|
||||
file.read_to_end(&mut buffer).unwrap();
|
||||
String::from_utf8(buffer).unwrap()
|
||||
}
|
|
@ -96,6 +96,7 @@ async fn chat_message_crud_test(pool: PgPool) {
|
|||
ChatAuthor::new(0, ChatAuthorType::System),
|
||||
&chat_id,
|
||||
format!("message {}", i),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
Loading…
Add table
Reference in a new issue