mirror of
https://github.com/AppFlowy-IO/AppFlowy-Cloud.git
synced 2025-04-19 03:24:42 -04:00
chore: generate ai image if need (#1254)
* chore: generate ai image * chore: fix test
This commit is contained in:
parent
1c38cdd23f
commit
06ddd7f755
7 changed files with 157 additions and 18 deletions
|
@ -32,6 +32,16 @@
|
|||
"ordinal": 5,
|
||||
"name": "status",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "source",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "source_metadata",
|
||||
"type_info": "Jsonb"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
|
@ -46,7 +56,9 @@
|
|||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
false,
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "441316f35ca8c24bf78167f9fec48e28c05969bbbbe3d0e3d9e1569a375de476"
|
||||
|
|
|
@ -32,6 +32,16 @@
|
|||
"ordinal": 5,
|
||||
"name": "status",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "source",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "source_metadata",
|
||||
"type_info": "Jsonb"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
|
@ -45,7 +55,9 @@
|
|||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
false,
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "74de473589a405c3ab567e72a881869321095e2de497b2c1866c547f939c359c"
|
||||
|
|
|
@ -254,6 +254,19 @@ impl AppFlowyAIClient {
|
|||
.into_data()
|
||||
}
|
||||
|
||||
pub async fn regenerate_image(&self, source_metadata: Value) -> Result<(), AIError> {
|
||||
let url = format!("{}/chat/image/regenerate", self.url);
|
||||
let resp = self
|
||||
.async_http_client(Method::POST, &url)?
|
||||
.json(&source_metadata)
|
||||
.timeout(Duration::from_secs(30))
|
||||
.send()
|
||||
.await?;
|
||||
AIResponse::<()>::from_reqwest_response(resp)
|
||||
.await?
|
||||
.into_data()
|
||||
}
|
||||
|
||||
pub async fn get_local_ai_package(
|
||||
&self,
|
||||
platform: &str,
|
||||
|
|
|
@ -202,19 +202,40 @@ pub struct AFCollabMemberRow {
|
|||
#[repr(i16)]
|
||||
pub enum AFBlobStatus {
|
||||
Ok = 0,
|
||||
DallEContentPolicyViolation = 1,
|
||||
PolicyViolation = 1,
|
||||
Failed = 2,
|
||||
Pending = 3,
|
||||
}
|
||||
|
||||
impl From<i16> for AFBlobStatus {
|
||||
fn from(value: i16) -> Self {
|
||||
match value {
|
||||
0 => AFBlobStatus::Ok,
|
||||
1 => AFBlobStatus::DallEContentPolicyViolation,
|
||||
1 => AFBlobStatus::PolicyViolation,
|
||||
2 => AFBlobStatus::Failed,
|
||||
3 => AFBlobStatus::Pending,
|
||||
_ => AFBlobStatus::Ok,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Eq, PartialEq, Debug, Clone)]
|
||||
#[repr(i16)]
|
||||
pub enum AFBlobSource {
|
||||
UserUpload = 0,
|
||||
AIGen = 1,
|
||||
}
|
||||
|
||||
impl From<i16> for AFBlobSource {
|
||||
fn from(value: i16) -> Self {
|
||||
match value {
|
||||
0 => AFBlobSource::UserUpload,
|
||||
1 => AFBlobSource::AIGen,
|
||||
_ => AFBlobSource::UserUpload,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow, Serialize, Deserialize)]
|
||||
pub struct AFBlobMetadataRow {
|
||||
pub workspace_id: Uuid,
|
||||
|
@ -224,6 +245,10 @@ pub struct AFBlobMetadataRow {
|
|||
pub modified_at: DateTime<Utc>,
|
||||
#[serde(default)]
|
||||
pub status: i16,
|
||||
#[serde(default)]
|
||||
pub source: i16,
|
||||
#[serde(default)]
|
||||
pub source_metadata: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
-- Add migration script here
|
||||
ALTER TABLE af_blob_metadata
|
||||
ADD COLUMN source SMALLINT NOT NULL DEFAULT 0,
|
||||
ADD COLUMN source_metadata JSONB DEFAULT '{}'::jsonb;
|
|
@ -23,9 +23,10 @@ use database_entity::file_dto::{
|
|||
use crate::biz::data_import::LimitedPayload;
|
||||
use crate::state::AppState;
|
||||
use anyhow::anyhow;
|
||||
use appflowy_ai_client::client::AppFlowyAIClient;
|
||||
use aws_sdk_s3::primitives::ByteStream;
|
||||
use collab_importer::util::FileId;
|
||||
use database::pg_row::AFBlobStatus;
|
||||
use database::pg_row::{AFBlobSource, AFBlobStatus};
|
||||
use serde::Deserialize;
|
||||
use shared_entity::dto::file_dto::PutFileResponse;
|
||||
use shared_entity::dto::workspace_dto::{BlobMetadata, RepeatedBlobMetaData, WorkspaceSpaceUsage};
|
||||
|
@ -35,7 +36,7 @@ use std::pin::Pin;
|
|||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_util::io::StreamReader;
|
||||
use tracing::{error, event, instrument, trace};
|
||||
use tracing::{error, event, info, instrument, trace};
|
||||
|
||||
pub fn file_storage_scope() -> Scope {
|
||||
web::scope("/api/file_storage")
|
||||
|
@ -370,12 +371,39 @@ async fn get_blob_by_object_key(
|
|||
}
|
||||
|
||||
let metadata = result.unwrap();
|
||||
match AFBlobStatus::from(metadata.status) {
|
||||
AFBlobStatus::DallEContentPolicyViolation => {
|
||||
return Ok(HttpResponse::UnprocessableEntity().finish());
|
||||
let source = AFBlobSource::from(metadata.source);
|
||||
trace!("blob metadata: {:?}", metadata);
|
||||
match source {
|
||||
AFBlobSource::UserUpload => {},
|
||||
AFBlobSource::AIGen => {
|
||||
let spawn_regenerate_image =
|
||||
|client: AppFlowyAIClient, source_metadata: serde_json::Value| {
|
||||
tokio::spawn(async move {
|
||||
info!("Regenerate ai image: {:?}", source_metadata);
|
||||
let _ = client.regenerate_image(source_metadata).await;
|
||||
});
|
||||
};
|
||||
let source_metadata = metadata.source_metadata;
|
||||
let status = AFBlobStatus::from(metadata.status);
|
||||
trace!("AI image {}: {:?}", key.object_key(), status);
|
||||
match status {
|
||||
AFBlobStatus::PolicyViolation => {
|
||||
return Ok(HttpResponse::UnprocessableEntity().finish());
|
||||
},
|
||||
AFBlobStatus::Pending => {
|
||||
if metadata.modified_at + chrono::Duration::minutes(1) < chrono::Utc::now() {
|
||||
spawn_regenerate_image(state.ai_client.clone(), source_metadata);
|
||||
} else {
|
||||
trace!("AI image is pending, wait for 1 minute");
|
||||
}
|
||||
},
|
||||
AFBlobStatus::Failed => {
|
||||
spawn_regenerate_image(state.ai_client.clone(), source_metadata);
|
||||
},
|
||||
_ => {},
|
||||
};
|
||||
},
|
||||
AFBlobStatus::Ok => {},
|
||||
};
|
||||
}
|
||||
|
||||
// Check if the file is modified since the last time
|
||||
if let Some(modified_since) = req
|
||||
|
|
|
@ -276,10 +276,45 @@ async fn generate_chat_message_answer_test() {
|
|||
.stream_answer_v2(&workspace_id, &chat_id, question.message_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let answer = collect_answer(answer_stream).await;
|
||||
let answer = collect_answer(answer_stream, None).await;
|
||||
assert!(!answer.is_empty());
|
||||
}
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn stop_streaming_test() {
|
||||
// if !ai_test_enabled() {
|
||||
// return;
|
||||
// }
|
||||
// let test_client = TestClient::new_user_without_ws_conn().await;
|
||||
// let workspace_id = test_client.workspace_id().await;
|
||||
// let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
// let params = CreateChatParams {
|
||||
// chat_id: chat_id.clone(),
|
||||
// name: "Stop streaming test".to_string(),
|
||||
// rag_ids: vec![],
|
||||
// };
|
||||
//
|
||||
// test_client
|
||||
// .api_client
|
||||
// .create_chat(&workspace_id, params)
|
||||
// .await
|
||||
// .unwrap();
|
||||
// let params = CreateChatMessageParams::new_user("when to use js");
|
||||
// let question = test_client
|
||||
// .api_client
|
||||
// .create_question(&workspace_id, &chat_id, params)
|
||||
// .await
|
||||
// .unwrap();
|
||||
// let answer_stream = test_client
|
||||
// .api_client
|
||||
// .stream_answer_v2(&workspace_id, &chat_id, question.message_id)
|
||||
// .await
|
||||
// .unwrap();
|
||||
// let answer = collect_answer(answer_stream, Some(1)).await;
|
||||
// println!("answer:\n{}", answer);
|
||||
// assert!(!answer.is_empty());
|
||||
// }
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_format_question_message_test() {
|
||||
if !ai_test_enabled() {
|
||||
|
@ -325,7 +360,7 @@ async fn get_format_question_message_test() {
|
|||
.stream_answer_v3(&workspace_id, query)
|
||||
.await
|
||||
.unwrap();
|
||||
let answer = collect_answer(answer_stream).await;
|
||||
let answer = collect_answer(answer_stream, None).await;
|
||||
println!("answer:\n{}", answer);
|
||||
assert!(!answer.is_empty());
|
||||
}
|
||||
|
@ -380,7 +415,7 @@ async fn get_text_with_image_message_test() {
|
|||
.stream_answer_v3(&workspace_id, query)
|
||||
.await
|
||||
.unwrap();
|
||||
let answer = collect_answer(answer_stream).await;
|
||||
let answer = collect_answer(answer_stream, None).await;
|
||||
println!("answer:\n{}", answer);
|
||||
let image_url = extract_image_url(&answer).unwrap();
|
||||
let (workspace_id_2, chat_id_2, file_id_2) = test_client
|
||||
|
@ -502,15 +537,25 @@ async fn get_model_list_test() {
|
|||
println!("models: {:?}", models);
|
||||
}
|
||||
|
||||
async fn collect_answer(mut stream: QuestionStream) -> String {
|
||||
async fn collect_answer(
|
||||
mut stream: QuestionStream,
|
||||
stop_when_num_of_char: Option<usize>,
|
||||
) -> String {
|
||||
let mut answer = String::new();
|
||||
let mut num_of_char: usize = 0;
|
||||
while let Some(value) = stream.next().await {
|
||||
match value.unwrap() {
|
||||
num_of_char += match value.unwrap() {
|
||||
QuestionStreamValue::Answer { value } => {
|
||||
answer.push_str(&value);
|
||||
value.len()
|
||||
},
|
||||
QuestionStreamValue::Metadata { .. } => {},
|
||||
QuestionStreamValue::KeepAlive => {},
|
||||
QuestionStreamValue::Metadata { .. } => 0,
|
||||
QuestionStreamValue::KeepAlive => 0,
|
||||
};
|
||||
if let Some(stop_when_num_of_char) = stop_when_num_of_char {
|
||||
if num_of_char >= stop_when_num_of_char {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
answer
|
||||
|
|
Loading…
Add table
Reference in a new issue