chore: enable chat with provided context (#713)

* chore: enable chat with provided context

* chore: rename

* chore: update create chat message api endpoint

* chore: use list context

* chore: use list context

* chore: fix test

* chore: update api endpoint

* chore: rename client api function

* chore: rename client api function

* chore: expose entity

* chore: update sqlx files

* chore: update test
This commit is contained in:
Nathan.fooo 2024-08-05 14:06:44 +08:00 committed by GitHub
parent 3b389d7911
commit a371912c61
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 766 additions and 125 deletions

View file

@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO af_chat_messages (chat_id, author, content)\n VALUES ($1, $2, $3)\n RETURNING message_id, created_at\n ",
"query": "\n INSERT INTO af_chat_messages (chat_id, author, content, meta_data)\n VALUES ($1, $2, $3, $4)\n RETURNING message_id, created_at\n ",
"describe": {
"columns": [
{
@ -18,7 +18,8 @@
"Left": [
"Uuid",
"Jsonb",
"Text"
"Text",
"Jsonb"
]
},
"nullable": [
@ -26,5 +27,5 @@
false
]
},
"hash": "6c0a058f5a2a53ad6f89fc7b9af214d7cd5026093bb3f1c94570c87441294977"
"hash": "09ff850490eab213cfa0ad88ece9ce7baa39beabee19754fd993268d29552eb9"
}

View file

@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO af_chat_messages (chat_id, author, content)\n VALUES ($1, $2, $3)\n RETURNING message_id, created_at\n ",
"query": "\n INSERT INTO af_chat_messages (chat_id, author, content, meta_data)\n VALUES ($1, $2, $3, $4)\n RETURNING message_id, created_at\n ",
"describe": {
"columns": [
{
@ -18,7 +18,8 @@
"Left": [
"Uuid",
"Jsonb",
"Text"
"Text",
"Jsonb"
]
},
"nullable": [
@ -26,5 +27,5 @@
false
]
},
"hash": "95b4d7508569cac38c78d21a0a471772d3703e5678ee7ca0cd32d60f5343be91"
"hash": "878399ad7c572c8c7c469229af15eadc7421ecc34844d05526196bc3a4147f6b"
}

View file

@ -1,16 +1,17 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE af_chat_messages\n SET content = $2,\n author = $3,\n created_at = CURRENT_TIMESTAMP\n WHERE message_id = $1\n ",
"query": "\n UPDATE af_chat_messages\n SET content = $2,\n author = $3,\n created_at = CURRENT_TIMESTAMP,\n meta_data = $4\n WHERE message_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8",
"Text",
"Jsonb",
"Jsonb"
]
},
"nullable": []
},
"hash": "8b4af677dd62367e68cdbf9a462e183c702cf98113cacf56a6acb6d130d2a79a"
"hash": "da1434fe116cbb48bc5aac0b6905dd748f096bf78d3cdcfea3a576b4aaeba5fc"
}

1
Cargo.lock generated
View file

@ -530,6 +530,7 @@ dependencies = [
"appflowy-ai-client",
"bytes",
"futures",
"pin-project",
"reqwest 0.12.5",
"serde",
"serde_json",

View file

@ -15,6 +15,7 @@ tracing = { version = "0.1", optional = true }
serde_repr = { version = "0.1", optional = true }
futures = "0.3.30"
bytes = "1.6.0"
pin-project = "1.1.5"
[dev-dependencies]
appflowy-ai-client = { path = ".", features = ["dto", "client-api"] }

View file

@ -1,20 +1,27 @@
use crate::dto::{
AIModel, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, Document,
EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData, RepeatedLocalAIPackage,
RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowData,
TranslateRowResponse,
AIModel, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, CreateTextChatContext,
Document, EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData,
RepeatedLocalAIPackage, RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse,
TranslateRowData, TranslateRowResponse,
};
use crate::error::AIError;
use bytes::Bytes;
use futures::{Stream, StreamExt};
use futures::{ready, Stream, StreamExt, TryStreamExt};
use reqwest;
use reqwest::{Method, RequestBuilder, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Map, Value};
use serde_json::{json, Map, StreamDeserializer, Value};
use anyhow::anyhow;
use pin_project::pin_project;
use serde::de::DeserializeOwned;
use serde_json::de::SliceRead;
use std::borrow::Cow;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use tracing::{info, trace};
const AI_MODEL_HEADER_KEY: &str = "ai-model";
@ -174,6 +181,20 @@ impl AppFlowyAIClient {
.into_data()
}
pub async fn create_chat_text_context(
&self,
context: CreateTextChatContext,
) -> Result<(), AIError> {
let url = format!("{}/chat/context/text", self.url);
let resp = self
.http_client(Method::POST, &url)?
.json(&context)
.send()
.await?;
let _ = AIResponse::<()>::from_response(resp).await?;
Ok(())
}
pub async fn send_question(
&self,
chat_id: &str,
@ -220,6 +241,28 @@ impl AppFlowyAIClient {
AIResponse::<()>::stream_response(resp).await
}
pub async fn stream_question_v2(
&self,
chat_id: &str,
content: &str,
model: &AIModel,
) -> Result<impl Stream<Item = Result<Bytes, AIError>>, AIError> {
let json = ChatQuestion {
chat_id: chat_id.to_string(),
data: MessageData {
content: content.to_string(),
},
};
let url = format!("{}/v2/chat/message/stream", self.url);
let resp = self
.http_client(Method::POST, &url)?
.header(AI_MODEL_HEADER_KEY, model.to_str())
.json(&json)
.send()
.await?;
AIResponse::<()>::stream_response(resp).await
}
pub async fn get_related_question(
&self,
chat_id: &str,
@ -307,6 +350,19 @@ where
.map(|item| item.map_err(|err| AIError::Internal(err.into())));
Ok(stream)
}
pub async fn json_stream_response(
resp: reqwest::Response,
) -> Result<impl Stream<Item = Result<T, AIError>>, AIError> {
let status_code = resp.status();
if !status_code.is_success() {
let body = resp.text().await?;
return Err(AIError::Internal(anyhow!(body)));
}
let stream = resp.bytes_stream().map_err(AIError::from);
Ok(JsonStream::new(stream))
}
}
impl From<reqwest::Error> for AIError {
fn from(error: reqwest::Error) -> Self {
@ -324,3 +380,62 @@ impl From<reqwest::Error> for AIError {
AIError::Internal(error.into())
}
}
/// Adapts a raw byte stream into a stream of deserialized JSON values of type `T`.
///
/// Incoming chunks are accumulated in `buffer` because a single JSON value may
/// arrive split across several network chunks.
#[pin_project]
pub struct JsonStream<T> {
  stream: Pin<Box<dyn Stream<Item = Result<Bytes, AIError>> + Send>>,
  // Bytes received so far that have not yet formed a complete JSON value.
  buffer: Vec<u8>,
  // `T` only appears in the `Stream` impl's output type.
  _marker: PhantomData<T>,
}

impl<T> JsonStream<T> {
  /// Wraps `stream` so its bytes are decoded into `T` values as they arrive.
  pub fn new<S>(stream: S) -> Self
  where
    S: Stream<Item = Result<Bytes, AIError>> + Send + 'static,
  {
    JsonStream {
      stream: Box::pin(stream),
      buffer: Vec::new(),
      _marker: PhantomData,
    }
  }
}
impl<T> Stream for JsonStream<T>
where
  T: DeserializeOwned,
{
  type Item = Result<T, AIError>;

  /// Polls the inner byte stream and yields the next complete JSON value.
  fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
    let this = self.project();
    // Loop until one full JSON value decodes, the inner stream errors, or it
    // finishes. Returning `Poll::Pending` right after the inner stream returned
    // `Ready` would stall the task forever: no waker is registered in that
    // case, so nothing ever wakes us again. The previous implementation did
    // exactly that on a partially received value; looping back to poll the
    // inner stream for more bytes is the fix.
    loop {
      match ready!(this.stream.as_mut().poll_next(cx)) {
        Some(Ok(bytes)) => {
          this.buffer.extend_from_slice(&bytes);
          let de = StreamDeserializer::new(SliceRead::new(this.buffer));
          let mut iter = de.into_iter();
          match iter.next() {
            Some(Ok(value)) => {
              // Drop the bytes consumed by the decoded value; any trailing
              // partial JSON stays buffered for the next poll.
              let consumed = iter.byte_offset();
              this.buffer.drain(0..consumed);
              return Poll::Ready(Some(Ok(value)));
            },
            // EOF here means the buffered bytes end mid-value: keep them and
            // poll the inner stream for the rest.
            Some(Err(err)) if err.is_eof() => continue,
            Some(Err(err)) => return Poll::Ready(Some(Err(AIError::Internal(err.into())))),
            // Nothing decodable yet (e.g. buffer still empty): ask for more.
            None => continue,
          }
        },
        Some(Err(err)) => return Poll::Ready(Some(Err(err))),
        None => return Poll::Ready(None),
      }
    }
  }
}

View file

@ -4,6 +4,8 @@ use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::str::FromStr;
// Keys used to tag frames in a streamed chat-answer JSON payload: a frame
// keyed "0" carries metadata, a frame keyed "1" carries answer text.
// NOTE(review): "STEAM" looks like a typo for "STREAM" — the names are kept
// because they are part of the public API; confirm before renaming.
pub const STEAM_METADATA_KEY: &str = "0";
pub const STEAM_ANSWER_KEY: &str = "1";
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SummarizeRowResponse {
pub text: String,
@ -23,6 +25,8 @@ pub struct MessageData {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChatAnswer {
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
@ -266,3 +270,26 @@ pub struct LocalAIConfig {
pub models: Vec<LLMModel>,
pub plugin: AppFlowyOfflineAI,
}
/// Request payload for attaching a piece of text to a chat as extra context,
/// so the AI service can draw on it when answering questions in that chat.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CreateTextChatContext {
  /// The chat this context belongs to.
  pub chat_id: String,
  /// Only supports "txt" and "md" for now.
  pub content_type: String,
  /// The raw text to register as context.
  pub text: String,
  /// Chunking parameters forwarded to the AI service — presumably the chunk
  /// length and overlap used when splitting `text` for indexing; confirm the
  /// units (bytes vs. characters) against the AI service.
  pub chunk_size: i32,
  pub chunk_overlap: i32,
  /// Arbitrary key/value pairs stored alongside the context.
  pub metadata: HashMap<String, serde_json::Value>,
}
impl Display for CreateTextChatContext {
  /// Summarizes the context for logging: reports the text's length, not its body.
  fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
    write!(
      f,
      "Create Chat context: {{ chat_id: {}, content_type: {}, content size: {}, metadata: {:?} }}",
      self.chat_id,
      self.content_type,
      self.text.len(),
      self.metadata
    )
  }
}

View file

@ -0,0 +1,23 @@
use crate::appflowy_ai_client;
use appflowy_ai_client::dto::{AIModel, CreateTextChatContext};
#[tokio::test]
async fn create_chat_context_test() {
  // Integration test: requires a running AppFlowy AI service (see appflowy_ai_client()).
  let client = appflowy_ai_client();
  let chat_id = uuid::Uuid::new_v4().to_string();
  // Seed the chat with a plain-text context document.
  let context = CreateTextChatContext {
    chat_id: chat_id.clone(),
    content_type: "txt".to_string(),
    text: "I have lived in the US for five years".to_string(),
    chunk_size: 1000,
    chunk_overlap: 20,
    metadata: Default::default(),
  };
  client.create_chat_text_context(context).await.unwrap();
  // Ask a question that can only be answered from the seeded context.
  let resp = client
    .send_question(&chat_id, "Where I live?", &AIModel::GPT35)
    .await
    .unwrap();
  // response will be something like:
  // Based on the context you provided, you have lived in the US for five years. Therefore, it is likely that you currently live in the US
  // The exact wording is nondeterministic, so only assert that an answer came back.
  assert!(!resp.content.is_empty());
}

View file

@ -1,3 +1,4 @@
mod completion_test;
mod context_test;
mod embedding_test;
mod qa_test;

View file

@ -1,6 +1,7 @@
use crate::appflowy_ai_client;
use appflowy_ai_client::dto::AIModel;
use appflowy_ai_client::client::JsonStream;
use appflowy_ai_client::dto::{AIModel, STEAM_ANSWER_KEY};
use futures::stream::StreamExt;
#[tokio::test]
@ -50,19 +51,23 @@ async fn stream_test() {
client.health_check().await.unwrap();
let chat_id = uuid::Uuid::new_v4().to_string();
let stream = client
.stream_question(&chat_id, "I feel hungry", &AIModel::GPT35)
.stream_question_v2(&chat_id, "I feel hungry", &AIModel::GPT35)
.await
.unwrap();
let json_stream = JsonStream::<serde_json::Value>::new(stream);
let stream = stream.map(|item| {
item.map(|bytes| {
String::from_utf8(bytes.to_vec())
.map(|s| s.replace('\n', ""))
.unwrap()
let messages: Vec<String> = json_stream
.filter_map(|item| async {
match item {
Ok(value) => value
.get(STEAM_ANSWER_KEY)
.and_then(|s| s.as_str().map(ToString::to_string)),
Err(_) => None,
}
})
});
.collect()
.await;
let messages: Vec<String> = stream.map(|message| message.unwrap()).collect().await;
println!("final answer: {}", messages.join(""));
}

View file

@ -8,7 +8,7 @@ use std::sync::Arc;
use uuid::Uuid;
use client_api::error::ErrorCode;
use console_error_panic_hook;
use database_entity::dto::QueryCollab;
use wasm_bindgen::prelude::*;

View file

@ -5,10 +5,17 @@ use client_api_entity::{
ChatMessage, CreateAnswerMessageParams, CreateChatMessageParams, CreateChatParams, MessageCursor,
RepeatedChatMessage, UpdateChatMessageContentParams,
};
use futures_core::Stream;
use futures_core::{ready, Stream};
use pin_project::pin_project;
use reqwest::Method;
use shared_entity::dto::ai_dto::RepeatedRelatedQuestion;
use serde_json::Value;
use shared_entity::dto::ai_dto::{
CreateTextChatContext, RepeatedRelatedQuestion, STEAM_ANSWER_KEY, STEAM_METADATA_KEY,
};
use shared_entity::response::{AppResponse, AppResponseError};
use std::pin::Pin;
use std::task::{Context, Poll};
use tracing::error;
impl Client {
/// Create a new chat
@ -45,7 +52,7 @@ impl Client {
}
/// Save a question message to a chat
pub async fn save_question(
pub async fn create_question(
&self,
workspace_id: &str,
chat_id: &str,
@ -91,14 +98,14 @@ impl Client {
}
/// Ask AI with a question for given question's message_id
pub async fn ask_question(
pub async fn stream_answer(
&self,
workspace_id: &str,
chat_id: &str,
message_id: i64,
question_message_id: i64,
) -> Result<impl Stream<Item = Result<Bytes, AppResponseError>>, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/{message_id}/answer/stream",
"{}/api/chat/{workspace_id}/{chat_id}/{question_message_id}/answer/stream",
self.base_url
);
let resp = self
@ -110,16 +117,36 @@ impl Client {
AppResponse::<()>::answer_response_stream(resp).await
}
/// Generate an answer for given question's message_id. The same as ask_question but return ChatMessage
/// instead of stream of Bytes
pub async fn generate_answer(
pub async fn stream_answer_v2(
&self,
workspace_id: &str,
chat_id: &str,
message_id: i64,
question_message_id: i64,
) -> Result<QuestionStream, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/{question_message_id}/v2/answer/stream",
self.base_url
);
let resp = self
.http_client_with_auth(Method::GET, &url)
.await?
.send()
.await?;
log_request_id(&resp);
let stream = AppResponse::<serde_json::Value>::json_response_stream(resp).await?;
Ok(QuestionStream::new(stream))
}
/// Generate an answer for the given question's message_id. Same as `stream_answer`,
/// but returns a `ChatMessage` instead of a stream of bytes.
pub async fn get_answer(
&self,
workspace_id: &str,
chat_id: &str,
question_message_id: i64,
) -> Result<ChatMessage, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/{message_id}/answer",
"{}/api/chat/{workspace_id}/{chat_id}/{question_message_id}/answer",
self.base_url
);
let resp = self
@ -231,4 +258,79 @@ impl Client {
log_request_id(&resp);
AppResponse::<ChatMessage>::json_response_stream(resp).await
}
pub async fn create_chat_context(
&self,
workspace_id: &str,
params: CreateTextChatContext,
) -> Result<(), AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{}/context/text",
self.base_url, params.chat_id
);
let resp = self
.http_client_with_auth(Method::POST, &url)
.await?
.json(&params)
.send()
.await?;
log_request_id(&resp);
AppResponse::<()>::from_response(resp).await?.into_error()
}
}
/// Stream of structured values decoded from a chat-answer JSON stream.
///
/// Wraps a stream of `serde_json::Value` frames and classifies each frame as
/// answer text or metadata (see `QuestionStreamValue`).
#[pin_project]
pub struct QuestionStream {
  stream: Pin<Box<dyn Stream<Item = Result<serde_json::Value, AppResponseError>> + Send>>,
  // The former `buffer: Vec<u8>` field was removed: it was never read or
  // written anywhere in this type (it was copied over from `JsonStream`,
  // which does buffer raw bytes — this type receives already-decoded values).
}

impl QuestionStream {
  /// Wraps `stream` so each JSON frame is classified as answer or metadata.
  pub fn new<S>(stream: S) -> Self
  where
    S: Stream<Item = Result<serde_json::Value, AppResponseError>> + Send + 'static,
  {
    QuestionStream {
      stream: Box::pin(stream),
    }
  }
}
/// One frame of a streamed answer.
pub enum QuestionStreamValue {
  /// A chunk of the answer's text content.
  Answer { value: String },
  /// Auxiliary metadata attached to the answer (arbitrary JSON).
  Metadata { value: serde_json::Value },
}
impl Stream for QuestionStream {
  type Item = Result<QuestionStreamValue, AppResponseError>;

  /// Classifies each incoming JSON object frame as metadata or answer text.
  /// A frame that is not an object, or an object carrying neither known key,
  /// is logged and terminates the stream.
  fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
    let this = self.project();
    match ready!(this.stream.as_mut().poll_next(cx)) {
      Some(Ok(Value::Object(mut value))) => {
        // A metadata frame takes precedence over an answer frame.
        if let Some(metadata) = value.remove(STEAM_METADATA_KEY) {
          Poll::Ready(Some(Ok(QuestionStreamValue::Metadata { value: metadata })))
        } else if let Some(answer) = value
          .remove(STEAM_ANSWER_KEY)
          .and_then(|s| s.as_str().map(ToString::to_string))
        {
          Poll::Ready(Some(Ok(QuestionStreamValue::Answer { value: answer })))
        } else {
          error!("Invalid streaming value: {:?}", value);
          Poll::Ready(None)
        }
      },
      Some(Ok(value)) => {
        error!("Unexpected JSON value type: {:?}", value);
        Poll::Ready(None)
      },
      Some(Err(err)) => Poll::Ready(Some(Err(err))),
      None => Poll::Ready(None),
    }
  }
}

View file

@ -26,6 +26,7 @@ pub use wasm::*;
#[cfg(not(target_arch = "wasm32"))]
mod http_chat;
mod http_search;
mod http_settings;
pub mod ws;
@ -37,6 +38,8 @@ pub mod error {
// Export all dto entities that will be used in the frontend application
pub mod entity {
#[cfg(not(target_arch = "wasm32"))]
pub use crate::http_chat::{QuestionStream, QuestionStreamValue};
pub use client_api_entity::*;
}

View file

@ -648,6 +648,35 @@ pub struct CreateChatMessageParams {
#[validate(custom = "validate_not_empty_str")]
pub content: String,
pub message_type: ChatMessageType,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
}
/// Metadata describing a contextual resource attached to a chat message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessageMetadata {
  /// The content payload (see [ChatMetadataData]).
  pub data: ChatMetadataData,
  // NOTE(review): unclear from here which entity this id refers to — confirm.
  pub id: String,
  pub name: String,
  pub source: String,
}
/// Content payload of a chat message metadata entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMetadataData {
  /// Don't rename this field, it's used by name in [ops::extract_chat_message_metadata]
  content: String,
  /// Kind of the content; "text" is the only kind produced here (see `new_text`).
  pub content_type: String,
  /// Byte length of `content`; used to validate the content on extraction.
  size: i64,
}
impl ChatMetadataData {
  /// Builds a "text" metadata entry; `size` is derived from the content's byte length.
  pub fn new_text(content: String) -> Self {
    Self {
      size: content.len() as i64,
      content_type: "text".to_string(),
      content,
    }
  }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -678,6 +707,7 @@ impl CreateChatMessageParams {
Self {
content: content.to_string(),
message_type: ChatMessageType::System,
metadata: None,
}
}
@ -685,8 +715,14 @@ impl CreateChatMessageParams {
Self {
content: content.to_string(),
message_type: ChatMessageType::User,
metadata: None,
}
}
pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
self.metadata = Some(metadata);
self
}
}
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
pub struct GetChatMessageParams {
@ -828,6 +864,9 @@ pub struct CreateAnswerMessageParams {
#[validate(custom = "validate_not_empty_str")]
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
pub question_message_id: i64,
}

View file

@ -11,6 +11,7 @@ use database_entity::dto::{
use serde_json::json;
use sqlx::postgres::PgArguments;
use sqlx::types::JsonValue;
use sqlx::{Arguments, Executor, PgPool, Postgres, Transaction};
use std::ops::DerefMut;
use std::str::FromStr;
@ -135,6 +136,7 @@ pub async fn insert_answer_message_with_transaction(
author: ChatAuthor,
chat_id: &str,
content: String,
metadata: serde_json::Value,
question_message_id: i64,
) -> Result<ChatMessage, AppError> {
let chat_id = Uuid::from_str(chat_id)?;
@ -156,12 +158,14 @@ pub async fn insert_answer_message_with_transaction(
UPDATE af_chat_messages
SET content = $2,
author = $3,
created_at = CURRENT_TIMESTAMP
created_at = CURRENT_TIMESTAMP,
meta_data = $4
WHERE message_id = $1
"#,
reply_id,
&content,
json!(author),
metadata,
)
.execute(transaction.deref_mut())
.await
@ -193,13 +197,14 @@ pub async fn insert_answer_message_with_transaction(
// Insert a new chat message
let row = sqlx::query!(
r#"
INSERT INTO af_chat_messages (chat_id, author, content)
VALUES ($1, $2, $3)
INSERT INTO af_chat_messages (chat_id, author, content, meta_data)
VALUES ($1, $2, $3, $4)
RETURNING message_id, created_at
"#,
chat_id,
json!(author),
&content,
&metadata,
)
.fetch_one(transaction.deref_mut())
.await
@ -224,7 +229,7 @@ pub async fn insert_answer_message_with_transaction(
message_id: row.message_id,
content,
created_at: row.created_at,
meta_data: Default::default(),
meta_data: metadata,
reply_message_id: None,
};
@ -237,12 +242,19 @@ pub async fn insert_answer_message(
author: ChatAuthor,
chat_id: &str,
content: String,
metadata: Option<serde_json::Value>,
question_message_id: i64,
) -> Result<ChatMessage, AppError> {
let mut txn = pg_pool.begin().await?;
let chat_message =
insert_answer_message_with_transaction(&mut txn, author, chat_id, content, question_message_id)
.await?;
let chat_message = insert_answer_message_with_transaction(
&mut txn,
author,
chat_id,
content,
metadata.unwrap_or_default(),
question_message_id,
)
.await?;
txn.commit().await.map_err(|err| {
AppError::Internal(anyhow!(
"Failed to commit transaction to insert answer message: {}",
@ -257,17 +269,20 @@ pub async fn insert_question_message<'a, E: Executor<'a, Database = Postgres>>(
author: ChatAuthor,
chat_id: &str,
content: String,
metadata: Option<JsonValue>,
) -> Result<ChatMessage, AppError> {
let metadata = metadata.unwrap_or_else(|| json!({}));
let chat_id = Uuid::from_str(chat_id)?;
let row = sqlx::query!(
r#"
INSERT INTO af_chat_messages (chat_id, author, content)
VALUES ($1, $2, $3)
INSERT INTO af_chat_messages (chat_id, author, content, meta_data)
VALUES ($1, $2, $3, $4)
RETURNING message_id, created_at
"#,
chat_id,
json!(author),
&content,
&metadata,
)
.fetch_one(executor)
.await
@ -278,7 +293,7 @@ pub async fn insert_question_message<'a, E: Executor<'a, Database = Postgres>>(
message_id: row.message_id,
content,
created_at: row.created_at,
meta_data: Default::default(),
meta_data: metadata,
reply_message_id: None,
};
Ok(chat_message)

View file

@ -89,32 +89,32 @@ where
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
match ready!(this.stream.as_mut().poll_next(cx)) {
Some(Ok(bytes)) => {
this.buffer.extend_from_slice(&bytes);
let de = StreamDeserializer::new(SliceRead::new(this.buffer));
let mut iter = de.into_iter();
if let Some(result) = iter.next() {
match result {
Ok(value) => {
let remaining = iter.byte_offset();
this.buffer.drain(0..remaining);
return Poll::Ready(Some(Ok(value)));
},
Err(err) => {
if err.is_eof() {
continue;
} else {
return Poll::Ready(Some(Err(AppResponseError::from(err))));
}
},
}
}
},
Some(Err(err)) => return Poll::Ready(Some(Err(err))),
None => return Poll::Ready(None),
}
match ready!(this.stream.as_mut().poll_next(cx)) {
Some(Ok(bytes)) => {
this.buffer.extend_from_slice(&bytes);
let de = StreamDeserializer::new(SliceRead::new(this.buffer));
let mut iter = de.into_iter();
if let Some(result) = iter.next() {
return match result {
Ok(value) => {
let remaining = iter.byte_offset();
this.buffer.drain(0..remaining);
Poll::Ready(Some(Ok(value)))
},
Err(err) => {
if err.is_eof() {
Poll::Pending
} else {
Poll::Ready(Some(Err(AppResponseError::from(err))))
}
},
};
} else {
Poll::Pending
}
},
Some(Err(err)) => Poll::Ready(Some(Err(err))),
None => Poll::Ready(None),
}
}
}

View file

@ -1,13 +1,14 @@
use crate::biz::chat::ops::{
create_chat, create_chat_message, create_chat_question, delete_chat,
generate_chat_message_answer, get_chat_messages, update_chat_message,
create_chat, create_chat_message, create_chat_message_stream, delete_chat,
extract_chat_message_metadata, generate_chat_message_answer, get_chat_messages,
update_chat_message, ExtractChatMetadata,
};
use crate::state::AppState;
use actix_web::web::{Data, Json};
use actix_web::{web, HttpRequest, HttpResponse, Scope};
use app_error::AppError;
use appflowy_ai_client::dto::RepeatedRelatedQuestion;
use appflowy_ai_client::dto::{CreateTextChatContext, RepeatedRelatedQuestion};
use authentication::jwt::UserUuid;
use bytes::Bytes;
use database_entity::dto::{
@ -52,20 +53,32 @@ pub fn chat_scope() -> Scope {
.route(web::put().to(update_chat_message_handler)),
)
.service(
// Create a question for given chat
// Creating a [ChatMessage] for given content.
// When client asks a question, it will use this API to create a chat message
web::resource("/{chat_id}/message/question").route(web::post().to(create_question_handler)),
)
// create an answer for given chat
.service(web::resource("/{chat_id}/message/answer").route(web::post().to(create_answer_handler)))
// Writing the final answer for a given chat.
// After the streaming is finished, the client will use this API to save the message to disk.
.service(web::resource("/{chat_id}/message/answer").route(web::post().to(save_answer_handler)))
.service(
// Generate answer for given question.
web::resource("/{chat_id}/{message_id}/answer").route(web::get().to(gen_answer_handler)),
// Use AI to generate a response for a specified message ID.
// To generate an answer for a given question, use "/answer/stream" to receive the answer in a stream.
web::resource("/{chat_id}/{message_id}/answer").route(web::get().to(answer_handler)),
)
// Stream the answer for given question.
// Use AI to generate a response for a specified message ID. This response will be return as a stream.
.service(
web::resource("/{chat_id}/{message_id}/answer/stream")
.route(web::get().to(answer_stream_handler)),
)
.service(
web::resource("/{chat_id}/{message_id}/v2/answer/stream")
.route(web::get().to(answer_stream_v2_handler)),
)
.service(
// Create chat context for a given chat.
web::resource("/{chat_id}/context/text")
.route(web::post().to(create_chat_context_handler))
)
}
async fn create_chat_handler(
path: web::Path<String>,
@ -105,7 +118,7 @@ async fn create_chat_message_handler(
let ai_model = ai_model_from_header(&req);
let uid = state.user_cache.get_user_uid(&uuid).await?;
let message_stream = create_chat_message(
let message_stream = create_chat_message_stream(
&state.pg_pool,
uid,
chat_id,
@ -122,6 +135,20 @@ async fn create_chat_message_handler(
)
}
#[instrument(level = "debug", skip_all, err)]
async fn create_chat_context_handler(
state: Data<AppState>,
payload: Json<CreateTextChatContext>,
) -> actix_web::Result<JsonAppResponse<()>> {
let params = payload.into_inner();
state
.ai_client
.create_chat_text_context(params)
.await
.map_err(AppError::from)?;
Ok(AppResponse::Ok().into())
}
async fn update_chat_message_handler(
state: Data<AppState>,
payload: Json<UpdateChatMessageContentParams>,
@ -155,14 +182,35 @@ async fn create_question_handler(
uuid: UserUuid,
) -> actix_web::Result<JsonAppResponse<ChatMessage>> {
let (_workspace_id, chat_id) = path.into_inner();
let params = payload.into_inner();
let mut params = payload.into_inner();
for extract_context in extract_chat_message_metadata(&mut params) {
match extract_context {
ExtractChatMetadata::Text { text, metadata } => {
let context = CreateTextChatContext {
chat_id: chat_id.clone(),
content_type: "txt".to_string(),
text,
chunk_size: 2000,
chunk_overlap: 20,
metadata,
};
trace!("create chat context: {}", context);
state
.ai_client
.create_chat_text_context(context)
.await
.map_err(AppError::from)?;
},
}
}
let uid = state.user_cache.get_user_uid(&uuid).await?;
let resp = create_chat_question(&state.pg_pool, uid, chat_id, params).await?;
let resp = create_chat_message(&state.pg_pool, uid, chat_id, params).await?;
Ok(AppResponse::Ok().with_data(resp).into())
}
async fn create_answer_handler(
async fn save_answer_handler(
path: web::Path<(String, String)>,
payload: Json<CreateAnswerMessageParams>,
state: Data<AppState>,
@ -176,13 +224,14 @@ async fn create_answer_handler(
ChatAuthor::ai(),
&chat_id,
payload.content,
payload.metadata,
payload.question_message_id,
)
.await?;
Ok(AppResponse::Ok().with_data(message).into())
}
async fn gen_answer_handler(
async fn answer_handler(
path: web::Path<(String, String, i64)>,
state: Data<AppState>,
req: HttpRequest,
@ -200,7 +249,7 @@ async fn gen_answer_handler(
Ok(AppResponse::Ok().with_data(message).into())
}
#[instrument(level = "info", skip_all, err)]
#[instrument(level = "debug", skip_all, err)]
async fn answer_stream_handler(
path: web::Path<(String, String, i64)>,
state: Data<AppState>,
@ -232,6 +281,38 @@ async fn answer_stream_handler(
}
}
#[instrument(level = "debug", skip_all, err)]
async fn answer_stream_v2_handler(
path: web::Path<(String, String, i64)>,
state: Data<AppState>,
req: HttpRequest,
) -> actix_web::Result<HttpResponse> {
let (_workspace_id, chat_id, question_id) = path.into_inner();
let content = chat::chat_ops::select_chat_message_content(&state.pg_pool, question_id).await?;
let ai_model = ai_model_from_header(&req);
match state
.ai_client
.stream_question_v2(&chat_id, &content, &ai_model)
.await
{
Ok(answer_stream) => {
let new_answer_stream = answer_stream.map_err(AppError::from);
Ok(
HttpResponse::Ok()
.content_type("text/event-stream")
.streaming(new_answer_stream),
)
},
Err(err) => Ok(
HttpResponse::Ok()
.content_type("text/event-stream")
.streaming(stream::once(async move {
Err(AppError::AIServiceUnavailable(err.to_string()))
})),
),
}
}
#[instrument(level = "debug", skip_all, err)]
async fn get_chat_message_handler(
path: web::Path<(String, String)>,

View file

@ -1,5 +1,6 @@
use actix_web::web::Bytes;
use anyhow::anyhow;
use std::collections::HashMap;
use app_error::AppError;
use appflowy_ai_client::client::AppFlowyAIClient;
@ -15,8 +16,9 @@ use database_entity::dto::{
CreateChatParams, GetChatMessageParams, RepeatedChatMessage, UpdateChatMessageContentParams,
};
use futures::stream::Stream;
use serde_json::Value;
use sqlx::PgPool;
use tracing::error;
use tracing::{error, info, trace};
use appflowy_ai_client::dto::AIModel;
use validator::Validate;
@ -65,6 +67,7 @@ pub async fn update_chat_message(
ChatAuthor::ai(),
&params.chat_id,
new_answer.content,
new_answer.metadata,
params.message_id,
)
.await?;
@ -84,6 +87,7 @@ pub async fn generate_chat_message_answer(
.send_question(chat_id, &content, &ai_model)
.await?;
info!("new_answer: {:?}", new_answer);
// Save the answer to the database
let mut txn = pg_pool.begin().await?;
let message = insert_answer_message_with_transaction(
@ -91,6 +95,7 @@ pub async fn generate_chat_message_answer(
ChatAuthor::ai(),
chat_id,
new_answer.content,
new_answer.metadata.unwrap_or_default(),
question_message_id,
)
.await?;
@ -109,6 +114,136 @@ pub async fn create_chat_message(
uid: i64,
chat_id: String,
params: CreateChatMessageParams,
) -> Result<ChatMessage, AppError> {
let params = params.clone();
let chat_id = chat_id.clone();
let pg_pool = pg_pool.clone();
let question = insert_question_message(
&pg_pool,
ChatAuthor::new(uid, ChatAuthorType::Human),
&chat_id,
params.content.clone(),
params.metadata,
)
.await?;
Ok(question)
}
/// Kind of context carried in a message-metadata "data" entry.
enum ContextType {
  /// Unrecognized or missing `content_type`; the entry is left untouched.
  Unknown,
  /// `content_type == "text"`; the entry's content can be used as chat context.
  Text,
}

/// Extracts the chat context from the metadata. Currently, we only support text as a context. In
/// the future, we will support other types of context.
pub(crate) enum ExtractChatMetadata {
  Text {
    /// The raw text content removed from the message metadata.
    text: String,
    /// The remaining metadata fields (with the "data" entry removed).
    metadata: HashMap<String, Value>,
  },
}
/// Removes the "content" field from the metadata if the "ty" field is equal to "text".
/// The metadata struct is shown below:
/// {
/// "data": {
/// "content": "hello world"
/// "size": 122,
/// "content_type": "text",
/// },
/// "id": "id",
/// "name": "name"
/// }
///
/// # Parameters
/// - `params`: A mutable reference to `CreateChatMessageParams` which contains metadata.
///
/// # Returns
/// - `Option<(String, HashMap<String, serde_json::Value>)>`: A tuple containing the removed content and the updated metadata, otherwise `None`.
fn extract_message_metadata(
message_metadata: &mut serde_json::Value,
) -> Option<ExtractChatMetadata> {
trace!("Extracting metadata: {:?}", message_metadata);
if let Value::Object(message_metadata) = message_metadata {
let mut context_type = ContextType::Unknown;
if let Some(Value::Object(data)) = message_metadata.get("data") {
if let Some(ty) = data.get("content_type").and_then(|v| v.as_str()) {
match ty {
"text" => context_type = ContextType::Text,
_ => context_type = ContextType::Unknown,
}
}
}
match context_type {
ContextType::Unknown => {
// do nothing
},
ContextType::Text => {
// remove the "data" field from the context if the "ty" field is equal to "text"
let mut text = None;
if let Some(Value::Object(ref mut data)) = message_metadata.remove("data") {
let content = data
.remove("content")
.and_then(|value| {
if let Value::String(s) = value {
Some(s)
} else {
None
}
})
.unwrap_or_default();
let content_size = data
.remove("size")
.and_then(|value| {
if let Value::Number(n) = value {
n.as_i64()
} else {
None
}
})
.unwrap_or(0);
// If the content is not empty and the content size is equal to the length of the content
if !content.is_empty() && content.len() == content_size as usize {
text = Some(content);
}
}
return text.map(|text| ExtractChatMetadata::Text {
text,
metadata: message_metadata.clone().into_iter().collect(),
});
},
}
}
None
}
pub(crate) fn extract_chat_message_metadata(
params: &mut CreateChatMessageParams,
) -> Vec<ExtractChatMetadata> {
let mut extract_metadatas = vec![];
if let Some(Value::Array(ref mut list)) = params.metadata {
trace!("Extracting chat metadata: {:?}", list);
for metadata in list {
if let Some(extract_context) = extract_message_metadata(metadata) {
extract_metadatas.push(extract_context);
}
}
}
extract_metadatas
}
pub async fn create_chat_message_stream(
pg_pool: &PgPool,
uid: i64,
chat_id: String,
params: CreateChatMessageParams,
ai_client: AppFlowyAIClient,
ai_model: AIModel,
) -> impl Stream<Item = Result<Bytes, AppError>> {
@ -121,7 +256,8 @@ pub async fn create_chat_message(
&pg_pool,
ChatAuthor::new(uid, ChatAuthorType::Human),
&chat_id,
params.content.clone()
params.content.clone(),
params.metadata,
).await {
Ok(question) => question,
Err(err) => {
@ -147,8 +283,8 @@ pub async fn create_chat_message(
match params.message_type {
ChatMessageType::System => {}
ChatMessageType::User => {
let content = match ai_client.send_question(&chat_id, &params.content, &ai_model).await {
Ok(response) => response.content,
let answer = match ai_client.send_question(&chat_id, &params.content, &ai_model).await {
Ok(response) => response,
Err(err) => {
error!("Failed to send question to AI: {}", err);
yield Err(AppError::from(err));
@ -156,7 +292,7 @@ pub async fn create_chat_message(
}
};
let answer = match insert_answer_message(&pg_pool, ChatAuthor::ai(), &chat_id, content.clone(),question_id).await {
let answer = match insert_answer_message(&pg_pool, ChatAuthor::ai(), &chat_id, answer.content, answer.metadata,question_id).await {
Ok(answer) => answer,
Err(err) => {
error!("Failed to insert answer message: {}", err);
@ -194,22 +330,3 @@ pub async fn get_chat_messages(
txn.commit().await?;
Ok(messages)
}
/// Persists a user-authored question message for `chat_id` and returns the
/// stored `ChatMessage` row.
///
/// # Errors
/// Propagates any database error raised while inserting the message.
pub async fn create_chat_question(
  pg_pool: &PgPool,
  uid: i64,
  chat_id: String,
  params: CreateChatMessageParams,
) -> Result<ChatMessage, AppError> {
  // `chat_id` and `params` are already owned and `insert_question_message`
  // only borrows the pool, so the previous params/chat_id/pg_pool clones
  // (and the clone of `params.content`) were all redundant.
  let question = insert_question_message(
    pg_pool,
    ChatAuthor::new(uid, ChatAuthorType::Human),
    &chat_id,
    params.content,
  )
  .await?;
  Ok(question)
}

View file

@ -0,0 +1 @@
I am Lucas. I have lived in Singapore for 10 years and work as a software engineer. I'm a fan of AI and enjoy exploring its potential.

View file

@ -1,6 +1,14 @@
use crate::ai_test::util::read_text_from_asset;
use appflowy_ai_client::dto::CreateTextChatContext;
use assert_json_diff::assert_json_eq;
use client_api::entity::QuestionStreamValue;
use client_api_test::TestClient;
use database_entity::dto::{ChatMessage, CreateChatMessageParams, CreateChatParams, MessageCursor};
use database_entity::dto::{
ChatMessage, ChatMessageMetadata, ChatMetadataData, CreateChatMessageParams, CreateChatParams,
MessageCursor,
};
use futures_util::StreamExt;
use serde_json::json;
#[tokio::test]
async fn create_chat_and_create_messages_test() {
@ -98,7 +106,7 @@ async fn chat_qa_test() {
let chat_id = uuid::Uuid::new_v4().to_string();
let params = CreateChatParams {
chat_id: chat_id.clone(),
name: "my second chat".to_string(),
name: "new chat".to_string(),
rag_ids: vec![],
};
@ -108,19 +116,42 @@ async fn chat_qa_test() {
.await
.unwrap();
let params = CreateChatMessageParams::new_user("where is singapore?");
let stream = test_client
let content = read_text_from_asset("my_profile.txt");
let metadata = ChatMessageMetadata {
data: ChatMetadataData::new_text(content),
id: "123".to_string(),
name: "test context".to_string(),
source: "user added".to_string(),
};
let params =
CreateChatMessageParams::new_user("Where lucas live?").with_metadata(json!(vec![metadata]));
let question = test_client
.api_client
.create_question_answer(&workspace_id, &chat_id, params)
.create_question(&workspace_id, &chat_id, params)
.await
.unwrap();
let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
assert_eq!(messages.len(), 2);
let answer = test_client
.api_client
.get_answer(&workspace_id, &chat_id, question.message_id)
.await
.unwrap();
assert!(answer.content.contains("Singapore"));
assert_json_eq!(
answer.meta_data,
json!([
{
"id": "123",
"name": "test context",
"source": "user added",
}
])
);
let related_questions = test_client
.api_client
.get_chat_related_question(&workspace_id, &chat_id, messages[1].message_id)
.get_chat_related_question(&workspace_id, &chat_id, question.message_id)
.await
.unwrap();
assert_eq!(related_questions.items.len(), 3);
@ -154,7 +185,7 @@ async fn generate_chat_message_answer_test() {
let answer = test_client
.api_client
.generate_answer(&workspace_id, &chat_id, messages[0].message_id)
.get_answer(&workspace_id, &chat_id, messages[0].message_id)
.await
.unwrap();
@ -188,23 +219,88 @@ async fn generate_stream_answer_test() {
let params = CreateChatMessageParams::new_user("Teach me how to write a article?");
let question = test_client
.api_client
.save_question(&workspace_id, &chat_id, params)
.create_question(&workspace_id, &chat_id, params)
.await
.unwrap();
// test v1 api endpoint
let mut answer_stream = test_client
.api_client
.ask_question(&workspace_id, &chat_id, question.message_id)
.stream_answer(&workspace_id, &chat_id, question.message_id)
.await
.unwrap();
let mut answer = String::new();
let mut answer_v1 = String::new();
while let Some(message) = answer_stream.next().await {
let message = message.unwrap();
let s = String::from_utf8(message.to_vec()).unwrap();
answer.push_str(&s);
answer_v1.push_str(&s);
}
assert!(!answer.is_empty());
assert!(!answer_v1.is_empty());
// test v2 api endpoint
let mut answer_stream = test_client
.api_client
.stream_answer_v2(&workspace_id, &chat_id, question.message_id)
.await
.unwrap();
let mut answer_v2 = String::new();
while let Some(value) = answer_stream.next().await {
match value.unwrap() {
QuestionStreamValue::Answer { value } => {
answer_v2.push_str(&value);
},
QuestionStreamValue::Metadata { .. } => {},
}
}
assert!(!answer_v2.is_empty());
}
/// End-to-end check that text context attached to a chat is actually used
/// when answering a question about that context.
#[tokio::test]
async fn create_chat_context_test() {
  let test_client = TestClient::new_user_without_ws_conn().await;
  let workspace_id = test_client.workspace_id().await;
  let chat_id = uuid::Uuid::new_v4().to_string();

  // Create an empty chat to attach context to.
  let params = CreateChatParams {
    chat_id: chat_id.clone(),
    name: "context chat".to_string(),
    rag_ids: vec![],
  };
  test_client
    .api_client
    .create_chat(&workspace_id, params)
    .await
    .unwrap();

  // Seed the chat with a plain-text context document.
  let context = CreateTextChatContext {
    chat_id: chat_id.clone(),
    content_type: "txt".to_string(),
    text: "I have lived in the US for five years".to_string(),
    chunk_size: 1000,
    chunk_overlap: 20,
    metadata: Default::default(),
  };
  test_client
    .api_client
    .create_chat_context(&workspace_id, context)
    .await
    .unwrap();

  // Ask a question whose answer can only come from the seeded context.
  let params = CreateChatMessageParams::new_user("Where I live?");
  let question = test_client
    .api_client
    .create_question(&workspace_id, &chat_id, params)
    .await
    .unwrap();
  let answer = test_client
    .api_client
    .get_answer(&workspace_id, &chat_id, question.message_id)
    .await
    .unwrap();

  // The answer should reference the seeded context. The former debug
  // `println!` after the assert is folded into the assertion message so the
  // payload is printed exactly when it is useful: on failure.
  assert!(
    answer.content.contains("US"),
    "answer did not use chat context: {:?}",
    answer
  );
}
// #[tokio::test]

View file

@ -2,3 +2,4 @@ mod chat_test;
mod complete_text;
// mod local_ai_test;
mod summarize_row;
mod util;

9
tests/ai_test/util.rs Normal file
View file

@ -0,0 +1,9 @@
use std::fs::File;
use std::io::Read;
/// Reads a UTF-8 text fixture from `./tests/ai_test/asset/` and returns its
/// contents as a `String`.
///
/// # Panics
/// Panics when the file is missing, unreadable, or not valid UTF-8 — a broken
/// fixture should fail the test immediately.
pub(crate) fn read_text_from_asset(file_name: &str) -> String {
  // `read_to_string` sizes the buffer from file metadata and validates UTF-8,
  // replacing the manual File::open + read_to_end + from_utf8 sequence.
  std::fs::read_to_string(format!("./tests/ai_test/asset/{}", file_name)).unwrap()
}

View file

@ -96,6 +96,7 @@ async fn chat_message_crud_test(pool: PgPool) {
ChatAuthor::new(0, ChatAuthorType::System),
&chat_id,
format!("message {}", i),
None,
)
.await
.unwrap();