chore: generate ai image if need (#1254)

* chore: generate ai image

* chore: fix test
This commit is contained in:
Nathan.fooo 2025-02-26 22:54:03 +08:00 committed by GitHub
parent 1c38cdd23f
commit 06ddd7f755
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 157 additions and 18 deletions

View file

@ -32,6 +32,16 @@
"ordinal": 5,
"name": "status",
"type_info": "Int2"
},
{
"ordinal": 6,
"name": "source",
"type_info": "Int2"
},
{
"ordinal": 7,
"name": "source_metadata",
"type_info": "Jsonb"
}
],
"parameters": {
@ -46,7 +56,9 @@
false,
false,
false,
false
false,
false,
true
]
},
"hash": "441316f35ca8c24bf78167f9fec48e28c05969bbbbe3d0e3d9e1569a375de476"

View file

@ -32,6 +32,16 @@
"ordinal": 5,
"name": "status",
"type_info": "Int2"
},
{
"ordinal": 6,
"name": "source",
"type_info": "Int2"
},
{
"ordinal": 7,
"name": "source_metadata",
"type_info": "Jsonb"
}
],
"parameters": {
@ -45,7 +55,9 @@
false,
false,
false,
false
false,
false,
true
]
},
"hash": "74de473589a405c3ab567e72a881869321095e2de497b2c1866c547f939c359c"

View file

@ -254,6 +254,19 @@ impl AppFlowyAIClient {
.into_data()
}
pub async fn regenerate_image(&self, source_metadata: Value) -> Result<(), AIError> {
let url = format!("{}/chat/image/regenerate", self.url);
let resp = self
.async_http_client(Method::POST, &url)?
.json(&source_metadata)
.timeout(Duration::from_secs(30))
.send()
.await?;
AIResponse::<()>::from_reqwest_response(resp)
.await?
.into_data()
}
pub async fn get_local_ai_package(
&self,
platform: &str,

View file

@ -202,19 +202,40 @@ pub struct AFCollabMemberRow {
#[repr(i16)]
pub enum AFBlobStatus {
Ok = 0,
DallEContentPolicyViolation = 1,
PolicyViolation = 1,
Failed = 2,
Pending = 3,
}
impl From<i16> for AFBlobStatus {
fn from(value: i16) -> Self {
match value {
0 => AFBlobStatus::Ok,
1 => AFBlobStatus::DallEContentPolicyViolation,
1 => AFBlobStatus::PolicyViolation,
2 => AFBlobStatus::Failed,
3 => AFBlobStatus::Pending,
_ => AFBlobStatus::Ok,
}
}
}
#[derive(Serialize, Deserialize, Eq, PartialEq, Debug, Clone)]
#[repr(i16)]
pub enum AFBlobSource {
UserUpload = 0,
AIGen = 1,
}
impl From<i16> for AFBlobSource {
fn from(value: i16) -> Self {
match value {
0 => AFBlobSource::UserUpload,
1 => AFBlobSource::AIGen,
_ => AFBlobSource::UserUpload,
}
}
}
#[derive(Debug, FromRow, Serialize, Deserialize)]
pub struct AFBlobMetadataRow {
pub workspace_id: Uuid,
@ -224,6 +245,10 @@ pub struct AFBlobMetadataRow {
pub modified_at: DateTime<Utc>,
#[serde(default)]
pub status: i16,
#[serde(default)]
pub source: i16,
#[serde(default)]
pub source_metadata: serde_json::Value,
}
#[derive(Debug, Deserialize, Serialize, Clone)]

View file

@ -0,0 +1,4 @@
-- Add migration script here
ALTER TABLE af_blob_metadata
ADD COLUMN source SMALLINT NOT NULL DEFAULT 0,
ADD COLUMN source_metadata JSONB DEFAULT '{}'::jsonb;

View file

@ -23,9 +23,10 @@ use database_entity::file_dto::{
use crate::biz::data_import::LimitedPayload;
use crate::state::AppState;
use anyhow::anyhow;
use appflowy_ai_client::client::AppFlowyAIClient;
use aws_sdk_s3::primitives::ByteStream;
use collab_importer::util::FileId;
use database::pg_row::AFBlobStatus;
use database::pg_row::{AFBlobSource, AFBlobStatus};
use serde::Deserialize;
use shared_entity::dto::file_dto::PutFileResponse;
use shared_entity::dto::workspace_dto::{BlobMetadata, RepeatedBlobMetaData, WorkspaceSpaceUsage};
@ -35,7 +36,7 @@ use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio_stream::StreamExt;
use tokio_util::io::StreamReader;
use tracing::{error, event, instrument, trace};
use tracing::{error, event, info, instrument, trace};
pub fn file_storage_scope() -> Scope {
web::scope("/api/file_storage")
@ -370,12 +371,39 @@ async fn get_blob_by_object_key(
}
let metadata = result.unwrap();
match AFBlobStatus::from(metadata.status) {
AFBlobStatus::DallEContentPolicyViolation => {
return Ok(HttpResponse::UnprocessableEntity().finish());
let source = AFBlobSource::from(metadata.source);
trace!("blob metadata: {:?}", metadata);
match source {
AFBlobSource::UserUpload => {},
AFBlobSource::AIGen => {
let spawn_regenerate_image =
|client: AppFlowyAIClient, source_metadata: serde_json::Value| {
tokio::spawn(async move {
info!("Regenerate ai image: {:?}", source_metadata);
let _ = client.regenerate_image(source_metadata).await;
});
};
let source_metadata = metadata.source_metadata;
let status = AFBlobStatus::from(metadata.status);
trace!("AI image {}: {:?}", key.object_key(), status);
match status {
AFBlobStatus::PolicyViolation => {
return Ok(HttpResponse::UnprocessableEntity().finish());
},
AFBlobStatus::Pending => {
if metadata.modified_at + chrono::Duration::minutes(1) < chrono::Utc::now() {
spawn_regenerate_image(state.ai_client.clone(), source_metadata);
} else {
trace!("AI image is pending, wait for 1 minute");
}
},
AFBlobStatus::Failed => {
spawn_regenerate_image(state.ai_client.clone(), source_metadata);
},
_ => {},
};
},
AFBlobStatus::Ok => {},
};
}
// Check if the file is modified since the last time
if let Some(modified_since) = req

View file

@ -276,10 +276,45 @@ async fn generate_chat_message_answer_test() {
.stream_answer_v2(&workspace_id, &chat_id, question.message_id)
.await
.unwrap();
let answer = collect_answer(answer_stream).await;
let answer = collect_answer(answer_stream, None).await;
assert!(!answer.is_empty());
}
// #[tokio::test]
// async fn stop_streaming_test() {
// if !ai_test_enabled() {
// return;
// }
// let test_client = TestClient::new_user_without_ws_conn().await;
// let workspace_id = test_client.workspace_id().await;
// let chat_id = uuid::Uuid::new_v4().to_string();
// let params = CreateChatParams {
// chat_id: chat_id.clone(),
// name: "Stop streaming test".to_string(),
// rag_ids: vec![],
// };
//
// test_client
// .api_client
// .create_chat(&workspace_id, params)
// .await
// .unwrap();
// let params = CreateChatMessageParams::new_user("when to use js");
// let question = test_client
// .api_client
// .create_question(&workspace_id, &chat_id, params)
// .await
// .unwrap();
// let answer_stream = test_client
// .api_client
// .stream_answer_v2(&workspace_id, &chat_id, question.message_id)
// .await
// .unwrap();
// let answer = collect_answer(answer_stream, Some(1)).await;
// println!("answer:\n{}", answer);
// assert!(!answer.is_empty());
// }
#[tokio::test]
async fn get_format_question_message_test() {
if !ai_test_enabled() {
@ -325,7 +360,7 @@ async fn get_format_question_message_test() {
.stream_answer_v3(&workspace_id, query)
.await
.unwrap();
let answer = collect_answer(answer_stream).await;
let answer = collect_answer(answer_stream, None).await;
println!("answer:\n{}", answer);
assert!(!answer.is_empty());
}
@ -380,7 +415,7 @@ async fn get_text_with_image_message_test() {
.stream_answer_v3(&workspace_id, query)
.await
.unwrap();
let answer = collect_answer(answer_stream).await;
let answer = collect_answer(answer_stream, None).await;
println!("answer:\n{}", answer);
let image_url = extract_image_url(&answer).unwrap();
let (workspace_id_2, chat_id_2, file_id_2) = test_client
@ -502,15 +537,25 @@ async fn get_model_list_test() {
println!("models: {:?}", models);
}
async fn collect_answer(mut stream: QuestionStream) -> String {
async fn collect_answer(
mut stream: QuestionStream,
stop_when_num_of_char: Option<usize>,
) -> String {
let mut answer = String::new();
let mut num_of_char: usize = 0;
while let Some(value) = stream.next().await {
match value.unwrap() {
num_of_char += match value.unwrap() {
QuestionStreamValue::Answer { value } => {
answer.push_str(&value);
value.len()
},
QuestionStreamValue::Metadata { .. } => {},
QuestionStreamValue::KeepAlive => {},
QuestionStreamValue::Metadata { .. } => 0,
QuestionStreamValue::KeepAlive => 0,
};
if let Some(stop_when_num_of_char) = stop_when_num_of_char {
if num_of_char >= stop_when_num_of_char {
break;
}
}
}
answer