chore: add limit for importing zip file (#938)

* chore: add limitation for import zip file

* chore: support upload big file

* chore: implement client api

* chore: implement client api

* chore: implement client api

* chore: update logs

* chore: check file size

* chore: last process at

* chore: set content type

* chore: fix test

* chore: try test

* chore: temporary disable test
This commit is contained in:
Nathan.fooo 2024-10-28 08:51:34 +08:00 committed by GitHub
parent 7cea106984
commit 9629d4cefa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 844 additions and 171 deletions

1
Cargo.lock generated
View file

@ -824,6 +824,7 @@ dependencies = [
"mime_guess",
"prometheus-client",
"redis 0.25.4",
"reqwest 0.12.5",
"secrecy",
"serde",
"serde_json",

View file

@ -166,6 +166,7 @@ services:
- APPFLOWY_WORKER_REDIS_URL=redis://redis:6379
- APPFLOWY_WORKER_ENVIRONMENT=production
- APPFLOWY_WORKER_DATABASE_URL=${APPFLOWY_WORKER_DATABASE_URL}
- APPFLOWY_WORKER_IMPORT_TICK_INTERVAL=30
- APPFLOWY_S3_USE_MINIO=${APPFLOWY_S3_USE_MINIO}
- APPFLOWY_S3_MINIO_URL=${APPFLOWY_S3_MINIO_URL}
- APPFLOWY_S3_ACCESS_KEY=${APPFLOWY_S3_ACCESS_KEY}

View file

@ -150,6 +150,9 @@ pub enum AppError {
#[error("{0}")]
MissingView(String),
#[error("{0}")]
TooManyImportTask(String),
#[error("There is existing access request for workspace {workspace_id} and view {view_id}")]
AccessRequestAlreadyExists { workspace_id: Uuid, view_id: Uuid },
}
@ -221,6 +224,7 @@ impl AppError {
AppError::NotInviteeOfWorkspaceInvitation(_) => ErrorCode::NotInviteeOfWorkspaceInvitation,
AppError::MissingView(_) => ErrorCode::MissingView,
AppError::AccessRequestAlreadyExists { .. } => ErrorCode::AccessRequestAlreadyExists,
AppError::TooManyImportTask(_) => ErrorCode::TooManyImportTask,
}
}
}
@ -352,6 +356,7 @@ pub enum ErrorCode {
AccessRequestAlreadyExists = 1043,
CustomNamespaceDisabled = 1044,
CustomNamespaceDisallowed = 1045,
TooManyImportTask = 1046,
}
impl ErrorCode {

View file

@ -4,29 +4,29 @@ use std::borrow::Cow;
use std::env;
use tracing::warn;
use uuid::Uuid;
#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
pub static ref LOCALHOST_URL: Cow<'static, str> =
get_env_var("LOCALHOST_URL", "http://localhost:8000");
pub static ref LOCALHOST_WS: Cow<'static, str> =
get_env_var("LOCALHOST_WS", "ws://localhost:8000/ws/v1");
pub static ref LOCALHOST_GOTRUE: Cow<'static, str> =
get_env_var("LOCALHOST_GOTRUE", "http://localhost:9999");
}
// Use following configuration when using local server with nginx
//
// #[cfg(not(target_arch = "wasm32"))]
// lazy_static! {
// pub static ref LOCALHOST_URL: Cow<'static, str> =
// get_env_var("LOCALHOST_URL", "http://localhost");
// get_env_var("LOCALHOST_URL", "http://localhost:8000");
// pub static ref LOCALHOST_WS: Cow<'static, str> =
// get_env_var("LOCALHOST_WS", "ws://localhost/ws/v1");
// get_env_var("LOCALHOST_WS", "ws://localhost:8000/ws/v1");
// pub static ref LOCALHOST_GOTRUE: Cow<'static, str> =
// get_env_var("LOCALHOST_GOTRUE", "http://localhost/gotrue");
// get_env_var("LOCALHOST_GOTRUE", "http://localhost:9999");
// }
// Use following configuration when using local server with nginx
//
#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
pub static ref LOCALHOST_URL: Cow<'static, str> =
get_env_var("LOCALHOST_URL", "http://localhost");
pub static ref LOCALHOST_WS: Cow<'static, str> =
get_env_var("LOCALHOST_WS", "ws://localhost/ws/v1");
pub static ref LOCALHOST_GOTRUE: Cow<'static, str> =
get_env_var("LOCALHOST_GOTRUE", "http://localhost/gotrue");
}
// The env vars are not available in wasm32-unknown-unknown
#[cfg(target_arch = "wasm32")]
lazy_static! {

View file

@ -7,9 +7,12 @@ use anyhow::anyhow;
use app_error::AppError;
use async_trait::async_trait;
use rayon::iter::ParallelIterator;
use std::fs::metadata;
use bytes::Bytes;
use client_api_entity::{CollabParams, PublishCollabItem, QueryCollabParams};
use client_api_entity::{
CollabParams, CreateImportTask, CreateImportTaskResponse, PublishCollabItem, QueryCollabParams,
};
use client_api_entity::{
CompleteUploadRequest, CreateUploadRequest, CreateUploadResponse, UploadPartResponse,
};
@ -40,7 +43,8 @@ use tokio::io::{AsyncBufReadExt, BufReader};
use tokio_retry::strategy::{ExponentialBackoff, FixedInterval};
use tokio_retry::{Condition, RetryIf};
use tokio_util::codec::{BytesCodec, FramedRead};
use tracing::{debug, event, info, instrument, trace};
use tracing::{debug, error, event, info, instrument, trace};
impl Client {
pub async fn stream_completion_text(
@ -362,6 +366,81 @@ impl Client {
AppResponse::<()>::from_response(resp).await?.into_error()
}
/// Creates an import task for a file and returns the import task response.
///
/// This function initiates an import task by sending a POST request to the
/// `/api/import/create` endpoint. The request includes the `workspace_name` derived
/// from the provided file's name (or a generated UUID if the file name cannot be determined).
///
/// After creating the import task, you should use [Self::upload_import_file] to upload
/// the actual file to the presigned URL obtained from the [CreateImportTaskResponse].
///
pub async fn create_import(
&self,
file_path: &Path,
) -> Result<CreateImportTaskResponse, AppResponseError> {
let url = format!("{}/api/import/create", self.base_url);
let file_name = file_path
.file_stem()
.map(|s| s.to_string_lossy().to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let content_length = tokio::fs::metadata(file_path).await?.len();
let params = CreateImportTask {
workspace_name: file_name.clone(),
content_length,
};
let resp = self
.http_client_with_auth(Method::POST, &url)
.await?
.header("X-Host", self.base_url.clone())
.json(&params)
.send()
.await?;
log_request_id(&resp);
AppResponse::<CreateImportTaskResponse>::from_response(resp)
.await?
.into_data()
}
/// Uploads a file to a specified presigned URL obtained from the import task response.
///
/// This function uploads a file to the given presigned URL using an HTTP PUT request.
/// The file's metadata is read to determine its size, and the upload stream is created
/// and sent to the provided URL. It is recommended to call this function after successfully
/// creating an import task using [Self::create_import].
///
pub async fn upload_import_file(
&self,
file_path: &Path,
url: &str,
) -> Result<(), AppResponseError> {
let file_metadata = metadata(file_path)?;
let file_size = file_metadata.len();
// Open the file
let file = File::open(file_path).await?;
let file_stream = FramedRead::new(file, BytesCodec::new());
let stream_body = Body::wrap_stream(file_stream);
trace!("start upload file to s3: {}", url);
let client = reqwest::Client::new();
let upload_resp = client
.put(url)
.header("Content-Length", file_size)
.header("Content-Type", "application/zip")
.body(stream_body)
.send()
.await?;
if !upload_resp.status().is_success() {
error!("File upload failed: {:?}", upload_resp);
return Err(AppError::S3ResponseError("Cannot upload file to S3".to_string()).into());
}
Ok(())
}
pub async fn get_import_list(&self) -> Result<UserImportTask, AppResponseError> {
let url = format!("{}/api/import", self.base_url);
let resp = self

View file

@ -1470,6 +1470,21 @@ pub struct ApproveAccessRequestParams {
pub is_approved: bool,
}
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
pub struct CreateImportTask {
#[validate(custom = "validate_not_empty_str")]
pub workspace_name: String,
pub content_length: u64,
}
/// Create a import task
/// Upload the import zip file to the presigned url
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateImportTaskResponse {
pub task_id: String,
pub presigned_url: String,
}
#[cfg(test)]
mod test {
use crate::dto::{

View file

@ -4,12 +4,13 @@ use app_error::AppError;
use async_trait::async_trait;
use aws_sdk_s3::operation::delete_object::DeleteObjectOutput;
use std::ops::Deref;
use aws_sdk_s3::error::SdkError;
use std::ops::Deref;
use std::time::{Duration, SystemTime};
use aws_sdk_s3::operation::delete_objects::DeleteObjectsOutput;
use aws_sdk_s3::operation::get_object::GetObjectError;
use aws_sdk_s3::presigning::PresigningConfig;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart, Delete, ObjectIdentifier};
use aws_sdk_s3::Client;
@ -40,6 +41,38 @@ impl AwsS3BucketClientImpl {
AwsS3BucketClientImpl { client, bucket }
}
pub async fn gen_presigned_url(
&self,
s3_key: &str,
content_length: u64,
expires_in_secs: u64,
) -> Result<String, AppError> {
let expires_in = Duration::from_secs(expires_in_secs);
let config = PresigningConfig::builder()
.start_time(SystemTime::now())
.expires_in(expires_in)
.build()
.map_err(|e| AppError::S3ResponseError(e.to_string()))?;
// There is no easy way to restrict file size of the upload (default limit max 5GB using PUT or other upload methods)
// https://github.com/aws/aws-sdk-net/issues/424
//
// consider using POST:
// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-HTTPPOSTConstructPolicy.html
let put_object_req = self
.client
.put_object()
.bucket(&self.bucket)
.key(s3_key)
.content_type("application/zip")
.content_length(content_length as i64)
.presigned(config)
.await
.map_err(|err| AppError::Internal(anyhow!("Generate presigned url failed: {:?}", err)))?;
let url = put_object_req.uri().to_string();
Ok(url)
}
async fn complete_upload_and_get_metadata(
&self,
object_key: &str,

View file

@ -560,6 +560,8 @@ pub struct AFImportTask {
pub status: i16,
pub metadata: serde_json::Value,
pub created_at: DateTime<Utc>,
#[serde(default)]
pub file_url: Option<String>,
}
#[derive(sqlx::Type, Serialize, Deserialize, Debug)]
#[repr(i32)]

View file

@ -28,7 +28,7 @@ pub async fn delete_from_workspace(pg_pool: &PgPool, workspace_id: &Uuid) -> Res
.execute(pg_pool)
.await?;
assert!(pg_row.rows_affected() == 1);
debug_assert!(pg_row.rows_affected() == 1);
Ok(())
}
@ -1439,12 +1439,24 @@ pub async fn select_user_is_invitee_for_workspace_invitation(
res.map_or(Ok(false), Ok)
}
pub async fn select_import_task(
pg_pool: &PgPool,
task_id: &Uuid,
) -> Result<AFImportTask, AppError> {
let query = String::from("SELECT * FROM af_import_task WHERE task_id = $1");
let import_task = sqlx::query_as::<_, AFImportTask>(&query)
.bind(task_id)
.fetch_one(pg_pool)
.await?;
Ok(import_task)
}
/// Get the import task for the user
/// Status of the file import (e.g., 0 for pending, 1 for completed, 2 for failed)
pub async fn select_import_task(
pub async fn select_import_task_by_state(
user_id: i64,
pg_pool: &PgPool,
filter_by_status: Option<i32>,
filter_by_status: Option<ImportTaskState>,
) -> Result<Vec<AFImportTask>, AppError> {
let mut query = String::from("SELECT * FROM af_import_task WHERE created_by = $1");
if filter_by_status.is_some() {
@ -1455,7 +1467,7 @@ pub async fn select_import_task(
let import_tasks = if let Some(status) = filter_by_status {
sqlx::query_as::<_, AFImportTask>(&query)
.bind(user_id)
.bind(status)
.bind(status as i32)
.fetch_all(pg_pool)
.await?
} else {
@ -1468,18 +1480,40 @@ pub async fn select_import_task(
Ok(import_tasks)
}
#[derive(Clone, Debug)]
pub enum ImportTaskState {
Pending = 0,
Completed = 1,
Failed = 2,
Expire = 3,
Cancel = 4,
}
impl From<i16> for ImportTaskState {
fn from(val: i16) -> Self {
match val {
0 => ImportTaskState::Pending,
1 => ImportTaskState::Completed,
2 => ImportTaskState::Failed,
4 => ImportTaskState::Cancel,
_ => ImportTaskState::Pending,
}
}
}
/// Update import task status
/// 0 => Pending,
/// 1 => Completed,
/// 2 => Failed,
/// 3 => Expire,
pub async fn update_import_task_status<'a, E: Executor<'a, Database = Postgres>>(
task_id: &Uuid,
new_status: i32,
new_status: ImportTaskState,
executor: E,
) -> Result<(), AppError> {
let query = "UPDATE af_import_task SET status = $1 WHERE task_id = $2";
sqlx::query(query)
.bind(new_status)
.bind(new_status as i16)
.bind(task_id)
.execute(executor)
.await
@ -1494,17 +1528,20 @@ pub async fn update_import_task_status<'a, E: Executor<'a, Database = Postgres>>
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub async fn insert_import_task(
uid: i64,
task_id: Uuid,
file_size: i64,
workspace_id: String,
created_by: i64,
metadata: Option<serde_json::Value>,
presigned_url: Option<String>,
pg_pool: &PgPool,
) -> Result<(), AppError> {
let query = r#"
INSERT INTO af_import_task (task_id, file_size, workspace_id, created_by, status, metadata)
VALUES ($1, $2, $3, $4, $5, COALESCE($6, '{}'))
INSERT INTO af_import_task (task_id, file_size, workspace_id, created_by, status, metadata, uid, file_url)
VALUES ($1, $2, $3, $4, $5, COALESCE($6, '{}'), $7, $8)
"#;
sqlx::query(query)
@ -1512,8 +1549,10 @@ pub async fn insert_import_task(
.bind(file_size)
.bind(workspace_id)
.bind(created_by)
.bind(0)
.bind(ImportTaskState::Pending as i32)
.bind(metadata)
.bind(uid)
.bind(presigned_url)
.execute(pg_pool)
.await
.map_err(|err| {

View file

@ -11,23 +11,5 @@ pub struct ImportTaskDetail {
pub task_id: String,
pub file_size: u64,
pub created_at: i64,
pub status: ImportTaskStatus,
}
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub enum ImportTaskStatus {
Pending,
Completed,
Failed,
}
impl From<i16> for ImportTaskStatus {
fn from(status: i16) -> Self {
match status {
0 => ImportTaskStatus::Pending,
1 => ImportTaskStatus::Completed,
2 => ImportTaskStatus::Failed,
_ => ImportTaskStatus::Pending,
}
}
pub status: i16,
}

View file

@ -0,0 +1,10 @@
-- Add migration script here
ALTER TABLE af_import_task
ADD COLUMN uid BIGINT,
ADD COLUMN file_url TEXT;
-- Update the existing index to include the new uid column
DROP INDEX IF EXISTS idx_af_import_task_status_created_at;
CREATE INDEX idx_af_import_task_uid_status_created_at
ON af_import_task (uid, status, created_at);

View file

@ -27,7 +27,7 @@ docker compose -f docker-compose-ci.yml pull
# SKIP_BUILD_APPFLOWY_CLOUD=true.
if [[ -z "${SKIP_BUILD_APPFLOWY_CLOUD+x}" ]]
then
docker build -t appflowy_cloud .
docker build -t appflowy_cloud . && docker build -t appflowy_worker .
fi
docker compose -f docker-compose-ci.yml up -d

View file

@ -49,4 +49,5 @@ mailer.workspace = true
md5.workspace = true
base64.workspace = true
prometheus-client = "0.22.3"
reqwest = "0.12.5"

View file

@ -19,6 +19,7 @@ use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::get;
use infra::env_util::get_env_var;
use mailer::sender::Mailer;
use std::sync::{Arc, Once};
use std::time::Duration;
@ -102,6 +103,10 @@ pub async fn create_app(listener: TcpListener, config: Config) -> Result<(), Err
let local_set = LocalSet::new();
let email_notifier = EmailNotifier::new(mailer);
let tick_interval = get_env_var("APPFLOWY_WORKER_IMPORT_TICK_INTERVAL", "10")
.parse::<u64>()
.unwrap_or(10);
let import_worker_fut = local_set.run_until(run_import_worker(
state.pg_pool.clone(),
state.redis_client.clone(),
@ -109,7 +114,7 @@ pub async fn create_app(listener: TcpListener, config: Config) -> Result<(), Err
Arc::new(state.s3_client.clone()),
Arc::new(email_notifier),
"import_task_stream",
10,
tick_interval,
));
let app = Router::new()

View file

@ -28,10 +28,25 @@ pub enum ImportError {
#[error("Failed to unzip file: {0}")]
UnZipFileError(String),
#[error("Upload file not found")]
UploadFileNotFound,
#[error("Upload file expired")]
UploadFileExpire,
#[error(transparent)]
Internal(#[from] anyhow::Error),
}
impl From<WorkerError> for ImportError {
fn from(err: WorkerError) -> ImportError {
match err {
WorkerError::RecordNotFound(_) => ImportError::UploadFileNotFound,
_ => ImportError::Internal(err.into()),
}
}
}
impl ImportError {
pub fn is_file_not_found(&self) -> bool {
match self {
@ -145,6 +160,24 @@ impl ImportError {
format!("Task ID: {} - Unzip file error", task_id),
)
}
ImportError::UploadFileNotFound => {
(
format!(
"Task ID: {} - The upload file could not be found. Please check the file and try again.",
task_id
),
format!("Task ID: {} - Upload file not found", task_id),
)
}
ImportError::UploadFileExpire => {
(
format!(
"Task ID: {} - The upload file has expired. Please upload the file again.",
task_id
),
format!("Task ID: {} - Upload file expired", task_id),
)
}
}
}
}

View file

@ -21,8 +21,9 @@ use database::collab::mem_cache::{cache_exp_secs_from_collab_type, CollabMemCach
use database::collab::{insert_into_af_collab_bulk_for_user, select_blob_from_af_collab};
use database::resource_usage::{insert_blob_metadata_bulk, BulkInsertMeta};
use database::workspace::{
delete_from_workspace, select_workspace_database_storage_id, update_import_task_status,
update_updated_at_of_workspace_with_uid, update_workspace_status,
delete_from_workspace, select_import_task, select_workspace_database_storage_id,
update_import_task_status, update_updated_at_of_workspace_with_uid, update_workspace_status,
ImportTaskState,
};
use database_entity::dto::CollabParams;
@ -30,6 +31,7 @@ use crate::metric::ImportMetrics;
use async_zip::base::read::stream::{Ready, ZipFileReader};
use collab_importer::zip_tool::async_zip::async_unzip;
use collab_importer::zip_tool::sync_zip::sync_unzip;
use futures::stream::FuturesUnordered;
use futures::{stream, AsyncBufRead, StreamExt};
use infra::env_util::get_env_var;
@ -39,11 +41,13 @@ use redis::streams::{
StreamReadReply,
};
use redis::{AsyncCommands, RedisResult, Value};
use database::pg_row::AFImportTask;
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use sqlx::types::chrono;
use sqlx::types::chrono::{DateTime, Utc};
use sqlx::{PgPool, Pool, Postgres};
use sqlx::PgPool;
use std::collections::{HashMap, HashSet};
use std::env::temp_dir;
use std::fmt::Display;
@ -64,6 +68,8 @@ use uuid::Uuid;
const GROUP_NAME: &str = "import_task_group";
const CONSUMER_NAME: &str = "appflowy_worker";
const MAXIMUM_CONTENT_LENGTH: &str = "3221225472";
pub async fn run_import_worker(
pg_pool: PgPool,
mut redis_client: ConnectionManager,
@ -129,18 +135,21 @@ async fn process_un_acked_tasks(
Ok(un_ack_tasks) => {
info!("Found {} unacknowledged tasks", un_ack_tasks.len());
for un_ack_task in un_ack_tasks {
let context = TaskContext {
storage_dir: storage_dir.to_path_buf(),
redis_client: redis_client.clone(),
s3_client: s3_client.clone(),
pg_pool: pg_pool.clone(),
notifier: notifier.clone(),
metrics: metrics.clone(),
};
// Ignore the error here since the consume task will handle the error
let _ = consume_task(
storage_dir,
context,
un_ack_task.task,
stream_name,
group_name,
un_ack_task.task,
&un_ack_task.stream_id.id,
redis_client,
s3_client,
pg_pool,
notifier.clone(),
metrics,
un_ack_task.stream_id.id,
)
.await;
}
@ -164,7 +173,7 @@ async fn process_upcoming_tasks(
) -> Result<(), ImportError> {
let options = StreamReadOptions::default()
.group(group_name, consumer_name)
.count(3);
.count(10);
let mut interval = interval(Duration::from_secs(interval_secs));
interval.tick().await;
@ -187,27 +196,23 @@ async fn process_upcoming_tasks(
for stream_id in stream_key.ids {
match ImportTask::try_from(&stream_id) {
Ok(import_task) => {
let entry_id = stream_id.id.clone();
let mut cloned_redis_client = redis_client.clone();
let cloned_s3_client = s3_client.clone();
let pg_pool = pg_pool.clone();
let notifier = notifier.clone();
let stream_name = stream_name.to_string();
let group_name = group_name.to_string();
let storage_dir = storage_dir.to_path_buf();
let metrics = metrics.clone();
let context = TaskContext {
storage_dir: storage_dir.to_path_buf(),
redis_client: redis_client.clone(),
s3_client: s3_client.clone(),
pg_pool: pg_pool.clone(),
notifier: notifier.clone(),
metrics: metrics.clone(),
};
task_handlers.push(spawn_local(async move {
consume_task(
&storage_dir,
context,
import_task,
&stream_name,
&group_name,
import_task,
&entry_id,
&mut cloned_redis_client,
&cloned_s3_client,
&pg_pool,
notifier,
&metrics,
stream_id.id,
)
.await?;
Ok::<(), ImportError>(())
@ -222,64 +227,225 @@ async fn process_upcoming_tasks(
while let Some(result) = task_handlers.next().await {
match result {
Ok(Ok(())) => trace!("Task completed successfully"),
Ok(Ok(())) => {},
Ok(Err(e)) => error!("Task failed: {:?}", e),
Err(e) => error!("Runtime error: {:?}", e),
}
}
}
}
#[derive(Clone)]
struct TaskContext {
storage_dir: PathBuf,
redis_client: ConnectionManager,
s3_client: Arc<dyn S3Client>,
pg_pool: PgPool,
notifier: Arc<dyn ImportNotifier>,
metrics: Option<Arc<ImportMetrics>>,
}
#[allow(clippy::too_many_arguments)]
async fn consume_task(
storage_dir: &Path,
mut context: TaskContext,
mut import_task: ImportTask,
stream_name: &str,
group_name: &str,
import_task: ImportTask,
entry_id: &String,
redis_client: &mut ConnectionManager,
s3_client: &Arc<dyn S3Client>,
pg_pool: &Pool<Postgres>,
notifier: Arc<dyn ImportNotifier>,
metrics: &Option<Arc<ImportMetrics>>,
entry_id: String,
) -> Result<(), ImportError> {
let result = process_task(
storage_dir,
import_task,
s3_client,
redis_client,
pg_pool,
notifier,
metrics,
)
.await;
if let ImportTask::Notion(task) = &mut import_task {
if let Some(created_at_timestamp) = task.created_at {
if is_task_expired(created_at_timestamp, task.last_process_at) {
if let Ok(import_record) = select_import_task(&context.pg_pool, &task.task_id).await {
handle_expired_task(
&mut context,
&import_record,
task,
stream_name,
group_name,
&entry_id,
)
.await?;
}
// Each task will be consumed only once, regardless of success or failure.
let _: () = redis_client
return Ok(());
} else if !check_blob_existence(&context.s3_client, &task.s3_key).await? {
task.last_process_at = Some(Utc::now().timestamp());
trace!("[Import] {} file not found, re-add task", task.workspace_id);
re_add_task(
&mut context.redis_client,
stream_name,
group_name,
import_task,
&entry_id,
)
.await?;
return Ok(());
}
}
}
process_and_ack_task(context, import_task, stream_name, group_name, &entry_id).await
}
async fn handle_expired_task(
context: &mut TaskContext,
import_record: &AFImportTask,
task: &NotionImportTask,
stream_name: &str,
group_name: &str,
entry_id: &str,
) -> Result<(), ImportError> {
info!(
"[Import]: {} import is expired, delete workspace",
task.workspace_id
);
update_import_task_status(
&import_record.task_id,
ImportTaskState::Expire,
&context.pg_pool,
)
.await
.map_err(|e| {
error!("Failed to update import task status: {:?}", e);
ImportError::Internal(e.into())
})?;
remove_workspace(&import_record.workspace_id, &context.pg_pool).await;
if let Err(err) = context.s3_client.delete_blob(task.s3_key.as_str()).await {
error!(
"[Import]: {} failed to delete zip file from S3: {:?}",
task.workspace_id, err
);
}
let _ = xack_task(&mut context.redis_client, stream_name, group_name, entry_id).await;
notify_user(
task,
Err(ImportError::UploadFileExpire),
context.notifier.clone(),
&context.metrics,
)
.await?;
Ok(())
}
async fn check_blob_existence(
s3_client: &Arc<dyn S3Client>,
s3_key: &str,
) -> Result<bool, ImportError> {
s3_client.is_blob_exist(s3_key).await.map_err(|e| {
error!("Failed to check blob existence: {:?}", e);
ImportError::Internal(e.into())
})
}
async fn process_and_ack_task(
mut context: TaskContext,
import_task: ImportTask,
stream_name: &str,
group_name: &str,
entry_id: &str,
) -> Result<(), ImportError> {
let result = process_task(context.clone(), import_task).await;
xack_task(&mut context.redis_client, stream_name, group_name, entry_id)
.await
.ok();
result
}
fn is_task_expired(timestamp: i64, last_process_at: Option<i64>) -> bool {
if last_process_at.is_none() {
return false;
}
match DateTime::<Utc>::from_timestamp(timestamp, 0) {
None => {
info!("[Import] failed to parse timestamp: {}", timestamp);
true
},
Some(created_at) => {
let now = Utc::now();
if created_at > now {
error!(
"[Import] created_at is in the future: {} > {}",
created_at, now
);
return false;
}
let elapsed = now - created_at;
let minutes = get_env_var("APPFLOWY_WORKER_IMPORT_TASK_EXPIRE_MINUTES", "10")
.parse::<i64>()
.unwrap_or(10);
elapsed.num_minutes() >= minutes
},
}
}
async fn re_add_task(
redis_client: &mut ConnectionManager,
stream_name: &str,
group_name: &str,
task: ImportTask,
entry_id: &str,
) -> Result<(), ImportError> {
let task_str = serde_json::to_string(&task).map_err(|e| {
error!("Failed to serialize task: {:?}", e);
ImportError::Internal(e.into())
})?;
let mut pipeline = redis::pipe();
pipeline
.atomic() // Ensures the commands are executed atomically
.cmd("XACK") // Acknowledge the task
.arg(stream_name)
.arg(group_name)
.arg(entry_id)
.ignore() // Ignore the result of XACK
.cmd("XADD") // Re-add the task to the stream
.arg(stream_name)
.arg("*")
.arg("task")
.arg(task_str);
let result: Result<(), redis::RedisError> = pipeline.query_async(redis_client).await;
match result {
Ok(_) => Ok(()),
Err(err) => {
error!(
"Failed to execute transaction for re-adding task: {:?}",
err
);
Err(ImportError::Internal(err.into()))
},
}
}
async fn xack_task(
redis_client: &mut ConnectionManager,
stream_name: &str,
group_name: &str,
entry_id: &str,
) -> Result<(), ImportError> {
redis_client
.xack(stream_name, group_name, &[entry_id])
.await
.map_err(|e| {
error!("Failed to acknowledge task: {:?}", e);
ImportError::Internal(e.into())
})?;
result
Ok(())
}
async fn process_task(
storage_dir: &Path,
mut context: TaskContext,
import_task: ImportTask,
s3_client: &Arc<dyn S3Client>,
redis_client: &mut ConnectionManager,
pg_pool: &PgPool,
notifier: Arc<dyn ImportNotifier>,
metrics: &Option<Arc<ImportMetrics>>,
) -> Result<(), ImportError> {
let retry_interval: u64 = get_env_var("APPFLOWY_WORKER_IMPORT_NOTION_RETRY_INTERVAL", "10")
let retry_interval: u64 = get_env_var("APPFLOWY_WORKER_IMPORT_TASK_RETRY_INTERVAL", "10")
.parse()
.unwrap_or(10);
let streaming = get_env_var("APPFLOWY_WORKER_IMPORT_NOTION_STREAMING", "false")
let streaming = get_env_var("APPFLOWY_WORKER_IMPORT_TASK_STREAMING", "false")
.parse()
.unwrap_or(false);
@ -292,13 +458,13 @@ async fn process_task(
ImportTask::Notion(task) => {
// 1. download zip file
let unzip_result = download_and_unzip_file_retry(
storage_dir,
&context.storage_dir,
&task,
s3_client,
&context.s3_client,
3,
Duration::from_secs(retry_interval),
streaming,
metrics,
&context.metrics,
)
.await;
@ -310,8 +476,14 @@ async fn process_task(
match unzip_result {
Ok(unzip_dir_path) => {
// 2. process unzip file
let result =
process_unzip_file(&task, &unzip_dir_path, pg_pool, redis_client, s3_client).await;
let result = process_unzip_file(
&task,
&unzip_dir_path,
&context.pg_pool,
&mut context.redis_client,
&context.s3_client,
)
.await;
// If there is any errors when processing the unzip file, we will remove the workspace and notify the user.
if result.is_err() {
@ -319,20 +491,20 @@ async fn process_task(
"[Import]: failed to import notion file, delete workspace:{}",
task.workspace_id
);
remove_workspace(&task.workspace_id, pg_pool).await;
remove_workspace(&task.workspace_id, &context.pg_pool).await;
}
clean_up(s3_client, &task).await;
notify_user(&task, result, notifier, metrics).await?;
clean_up(&context.s3_client, &task).await;
notify_user(&task, result, context.notifier, &context.metrics).await?;
},
Err(err) => {
// If there is any errors when download or unzip the file, we will remove the file from S3 and notify the user.
if let Err(err) = s3_client.delete_blob(task.s3_key.as_str()).await {
if let Err(err) = &context.s3_client.delete_blob(task.s3_key.as_str()).await {
error!("Failed to delete zip file from S3: {:?}", err);
}
remove_workspace(&task.workspace_id, pg_pool).await;
clean_up(s3_client, &task).await;
notify_user(&task, Err(err), notifier, metrics).await?;
remove_workspace(&task.workspace_id, &context.pg_pool).await;
clean_up(&context.s3_client, &task).await;
notify_user(&task, Err(err), context.notifier, &context.metrics).await?;
},
}
@ -346,7 +518,8 @@ async fn process_task(
is_success: true,
value: Default::default(),
};
notifier
context
.notifier
.notify_progress(ImportProgress::Finished(result))
.await;
Ok(())
@ -372,19 +545,25 @@ pub async fn download_and_unzip_file_retry(
attempt += 1;
match download_and_unzip_file(storage_dir, import_task, s3_client, streaming, metrics).await {
Ok(result) => return Ok(result),
Err(err) if attempt < max_retries && !err.is_file_not_found() => {
warn!(
"{} attempt {} failed: {}. Retrying in {:?}...",
import_task.workspace_id, attempt, err, interval
);
tokio::time::sleep(interval).await;
},
Err(err) => {
return Err(ImportError::Internal(anyhow!(
"Failed after {} attempts: {}",
attempt,
err
)));
// If the Upload file not found error occurs, we will not retry.
if matches!(err, ImportError::UploadFileNotFound) {
return Err(err);
}
if attempt < max_retries && !err.is_file_not_found() {
warn!(
"{} attempt {} failed: {}. Retrying in {:?}...",
import_task.workspace_id, attempt, err, interval
);
tokio::time::sleep(interval).await;
} else {
return Err(ImportError::Internal(anyhow!(
"Failed after {} attempts: {}",
attempt,
err
)));
}
},
}
}
@ -401,14 +580,59 @@ async fn download_and_unzip_file(
streaming: bool,
metrics: &Option<Arc<ImportMetrics>>,
) -> Result<PathBuf, ImportError> {
let blob_meta = s3_client.get_blob_meta(import_task.s3_key.as_str()).await?;
match blob_meta.content_type {
None => {
error!(
"[Import] {} failed to get content type for file: {:?}",
import_task.workspace_id, import_task.s3_key
);
},
Some(content_type) => {
let valid_zip_types = [
"application/zip",
"application/x-zip-compressed",
"multipart/x-zip",
"application/x-compressed",
];
if !valid_zip_types.contains(&content_type.as_str()) {
return Err(ImportError::Internal(anyhow!(
"Invalid content type: {}",
content_type
)));
}
},
}
let max_content_length = get_env_var(
"APPFLOWY_WORKER_IMPORT_TASK_MAX_FILE_SIZE_BYTES",
MAXIMUM_CONTENT_LENGTH,
)
.parse::<i64>()
.unwrap();
if blob_meta.content_length > max_content_length {
return Err(ImportError::Internal(anyhow!(
"File size is too large: {} bytes, max allowed: {} bytes",
blob_meta.content_length,
max_content_length
)));
}
trace!(
"[Import] {} start download file: {:?}, size: {}",
import_task.workspace_id,
import_task.s3_key,
blob_meta.content_length
);
let S3StreamResponse {
stream,
content_type: _,
content_length,
} = s3_client
.get_blob_stream(import_task.s3_key.as_str())
.await
.map_err(|err| ImportError::Internal(err.into()))?;
.await?;
let buffer_size = buffer_size_from_content_length(content_length);
if let Some(metrics) = metrics {
@ -478,7 +702,7 @@ enum StreamOrFile {
/// Asynchronously returns a `ZipFileReader` that can read from a stream or a downloaded file, based on the environment setting.
///
/// This function checks whether streaming is enabled via the `APPFLOWY_WORKER_IMPORT_NOTION_STREAMING` environment variable.
/// This function checks whether streaming is enabled via the `APPFLOWY_WORKER_IMPORT_TASK_STREAMING` environment variable.
/// If streaming is enabled, it reads the zip file directly from the provided stream.
/// Otherwise, it first downloads the zip file to a local file and then reads from it.
///
@ -755,14 +979,18 @@ async fn process_unzip_file(
import_task.workspace_id,
import_task.task_id,
);
update_import_task_status(&import_task.task_id, 1, transaction.deref_mut())
.await
.map_err(|err| {
ImportError::Internal(anyhow!(
"Failed to update import task status when importing data: {:?}",
err
))
})?;
update_import_task_status(
&import_task.task_id,
ImportTaskState::Completed,
transaction.deref_mut(),
)
.await
.map_err(|err| {
ImportError::Internal(anyhow!(
"Failed to update import task status when importing data: {:?}",
err
))
})?;
trace!(
"[Import]: {} set is_initialized to true",
@ -1081,8 +1309,13 @@ pub struct NotionImportTask {
pub s3_key: String,
pub host: String,
#[serde(default)]
pub created_at: Option<i64>,
#[serde(default)]
pub md5_base64: Option<String>,
#[serde(default)]
pub last_process_at: Option<i64>,
}
impl Display for NotionImportTask {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
@ -1096,7 +1329,8 @@ impl Display for NotionImportTask {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ImportTask {
Notion(NotionImportTask),
// boxing the large fields to reduce the total size of the enum
Notion(Box<NotionImportTask>),
Custom(serde_json::Value),
}

View file

@ -5,6 +5,7 @@ use std::fs::Permissions;
use anyhow::Result;
use aws_sdk_s3::operation::get_object::GetObjectError;
use aws_sdk_s3::operation::head_object::{HeadObjectError, HeadObjectOutput};
use aws_sdk_s3::primitives::ByteStream;
use axum::async_trait;
use base64::engine::general_purpose::STANDARD;
@ -30,6 +31,14 @@ pub trait S3Client: Send + Sync {
content_type: Option<&str>,
) -> Result<(), WorkerError>;
async fn delete_blob(&self, object_key: &str) -> Result<(), WorkerError>;
async fn is_blob_exist(&self, object_key: &str) -> Result<bool, WorkerError>;
async fn get_blob_meta(&self, object_key: &str) -> Result<BlobMeta, WorkerError>;
}
pub struct BlobMeta {
pub content_length: i64,
pub content_type: Option<String>,
}
#[derive(Clone, Debug)]
@ -38,6 +47,25 @@ pub struct S3ClientImpl {
pub bucket: String,
}
impl S3ClientImpl {
async fn get_head_object(&self, object_key: &str) -> Result<HeadObjectOutput, WorkerError> {
self
.inner
.head_object()
.bucket(&self.bucket)
.key(object_key)
.send()
.await
.map_err(|err| match err {
SdkError::ServiceError(service_err) => match service_err.err() {
HeadObjectError::NotFound(_) => WorkerError::RecordNotFound("blob not found".to_string()),
_ => WorkerError::from(anyhow!("Failed to head object from S3: {:?}", service_err)),
},
_ => WorkerError::from(anyhow!("Failed to head object from S3: {}", err)),
})
}
}
impl Deref for S3ClientImpl {
type Target = aws_sdk_s3::Client;
@ -108,6 +136,7 @@ impl S3Client for S3ClientImpl {
}
async fn delete_blob(&self, object_key: &str) -> Result<(), WorkerError> {
trace!("Deleting object from S3: {}", object_key);
match self
.inner
.delete_object()
@ -127,6 +156,27 @@ impl S3Client for S3ClientImpl {
))),
}
}
async fn is_blob_exist(&self, object_key: &str) -> Result<bool, WorkerError> {
let result = self.get_head_object(object_key).await;
match result {
Ok(_) => Ok(true),
Err(err) => match err {
WorkerError::RecordNotFound(_) => Ok(false),
_ => Err(err),
},
}
}
async fn get_blob_meta(&self, object_key: &str) -> Result<BlobMeta, WorkerError> {
let output = self.get_head_object(object_key).await?;
let content_length = output.content_length.unwrap_or(0);
let content_type = output.content_type;
Ok(BlobMeta {
content_length,
content_type,
})
}
}
pub struct S3StreamResponse {

View file

@ -2,7 +2,7 @@ use anyhow::Result;
use appflowy_worker::error::WorkerError;
use appflowy_worker::import_worker::report::{ImportNotifier, ImportProgress};
use appflowy_worker::import_worker::worker::{run_import_worker, ImportTask};
use appflowy_worker::s3_client::{S3Client, S3StreamResponse};
use appflowy_worker::s3_client::{BlobMeta, S3Client, S3StreamResponse};
use aws_sdk_s3::primitives::ByteStream;
use axum::async_trait;
@ -218,6 +218,14 @@ impl S3Client for MockS3Client {
async fn delete_blob(&self, _object_key: &str) -> Result<(), WorkerError> {
Ok(())
}
async fn is_blob_exist(&self, _object_key: &str) -> Result<bool, WorkerError> {
Ok(false)
}
async fn get_blob_meta(&self, _object_key: &str) -> Result<BlobMeta, WorkerError> {
todo!()
}
}
pub fn setup_log() {

View file

@ -1,6 +1,6 @@
use crate::state::AppState;
use actix_multipart::Multipart;
use actix_web::web::Data;
use actix_web::web::{Data, Json};
use actix_web::{web, HttpRequest, Scope};
use anyhow::anyhow;
use app_error::AppError;
@ -8,28 +8,107 @@ use authentication::jwt::UserUuid;
use aws_sdk_s3::primitives::ByteStream;
use database::file::BucketClient;
use crate::biz::workspace::ops::{create_empty_workspace, create_upload_task};
use crate::biz::workspace::ops::{create_empty_workspace, create_upload_task, num_pending_task};
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use database::user::select_name_and_email_from_uuid;
use database::workspace::select_import_task;
use database::workspace::select_import_task_by_state;
use database_entity::dto::{CreateImportTask, CreateImportTaskResponse};
use futures_util::StreamExt;
use infra::env_util::get_env_var;
use serde_json::json;
use shared_entity::dto::import_dto::{ImportTaskDetail, ImportTaskStatus, UserImportTask};
use shared_entity::dto::import_dto::{ImportTaskDetail, UserImportTask};
use shared_entity::response::{AppResponse, JsonAppResponse};
use std::env::temp_dir;
use std::path::PathBuf;
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
use tracing::{error, info, trace};
use tracing::{error, info, instrument, trace};
use uuid::Uuid;
use validator::Validate;
pub fn data_import_scope() -> Scope {
web::scope("/api/import").service(
web::resource("")
.route(web::post().to(import_data_handler))
.route(web::get().to(get_import_detail_handler)),
web::scope("/api/import")
.service(
web::resource("")
.route(web::post().to(import_data_handler))
.route(web::get().to(get_import_detail_handler)),
)
.service(web::resource("/create").route(web::post().to(create_import_handler)))
}
#[instrument(level = "debug", skip_all)]
async fn create_import_handler(
user_uuid: UserUuid,
state: Data<AppState>,
payload: Json<CreateImportTask>,
req: HttpRequest,
) -> actix_web::Result<JsonAppResponse<CreateImportTaskResponse>> {
let params = payload.into_inner();
params.validate().map_err(AppError::from)?;
let uid = state.user_cache.get_user_uid(&user_uuid).await?;
check_maximum_task(&state, uid).await?;
let s3_key = format!("import_presigned_url_{}", Uuid::new_v4());
// Generate presigned url with 10 minutes expiration
let presigned_url = state
.bucket_client
.gen_presigned_url(&s3_key, params.content_length, 600)
.await?;
trace!("[Import] Presigned url: {}", presigned_url);
let (user_name, user_email) = select_name_and_email_from_uuid(&state.pg_pool, &user_uuid).await?;
let host = get_host_from_request(&req);
let workspace = create_empty_workspace(
&state.pg_pool,
state.workspace_access_control.clone(),
&state.collab_access_control_storage,
&user_uuid,
uid,
&params.workspace_name,
)
.await?;
let workspace_id = workspace.workspace_id.to_string();
info!(
"User:{} import new workspace:{}, name:{}",
uid, workspace_id, params.workspace_name,
);
let timestamp = chrono::Utc::now().timestamp();
let task_id = Uuid::new_v4();
let task = json!({
"notion": {
"uid": uid,
"user_name": user_name,
"user_email": user_email,
"task_id": task_id.to_string(),
"workspace_id": workspace_id,
"created_at": timestamp,
"s3_key": s3_key,
"host": host,
"workspace_name": &params.workspace_name,
}
});
let data = CreateImportTaskResponse {
task_id: task_id.to_string(),
presigned_url: presigned_url.clone(),
};
create_upload_task(
uid,
task_id,
task,
&host,
&workspace_id,
0,
Some(presigned_url),
&state.redis_connection_manager,
&state.pg_pool,
)
.await?;
Ok(AppResponse::Ok().with_data(data).into())
}
async fn get_import_detail_handler(
@ -37,7 +116,7 @@ async fn get_import_detail_handler(
state: Data<AppState>,
) -> actix_web::Result<JsonAppResponse<UserImportTask>> {
let uid = state.user_cache.get_user_uid(&user_uuid).await?;
let tasks = select_import_task(uid, &state.pg_pool, None)
let tasks = select_import_task_by_state(uid, &state.pg_pool, None)
.await
.map(|tasks| {
tasks
@ -46,7 +125,7 @@ async fn get_import_detail_handler(
task_id: task.task_id.to_string(),
file_size: task.file_size as u64,
created_at: task.created_at.timestamp(),
status: ImportTaskStatus::from(task.status),
status: task.status,
})
.collect::<Vec<_>>()
})?;
@ -68,6 +147,8 @@ async fn import_data_handler(
req: HttpRequest,
) -> actix_web::Result<JsonAppResponse<()>> {
let uid = state.user_cache.get_user_uid(&user_uuid).await?;
check_maximum_task(&state, uid).await?;
let (user_name, user_email) = select_name_and_email_from_uuid(&state.pg_pool, &user_uuid).await?;
let host = get_host_from_request(&req);
let content_length = req
@ -169,6 +250,7 @@ async fn import_data_handler(
&host,
&workspace_id,
file.size,
None,
&state.redis_connection_manager,
&state.pg_pool,
)
@ -177,6 +259,21 @@ async fn import_data_handler(
Ok(AppResponse::Ok().into())
}
async fn check_maximum_task(state: &Data<AppState>, uid: i64) -> Result<(), AppError> {
let count = num_pending_task(uid, &state.pg_pool).await?;
let maximum_pending_task = get_env_var("MAXIMUM_IMPORT_PENDING_TASK", "3")
.parse::<i64>()
.unwrap_or(3);
if count >= maximum_pending_task {
return Err(AppError::TooManyImportTask(format!(
"{} tasks are pending. Please wait until they are completed",
count
)));
}
Ok(())
}
pub struct AutoDeletedFile {
name: String,
file_path: PathBuf,

View file

@ -683,16 +683,19 @@ pub async fn create_upload_task(
host: &str,
workspace_id: &str,
file_size: usize,
presigned_url: Option<String>,
redis_client: &RedisConnectionManager,
pg_pool: &PgPool,
) -> Result<(), AppError> {
// Insert the task into the database
insert_import_task(
uid,
task_id,
file_size as i64,
workspace_id.to_string(),
uid,
Some(json!({"host": host})),
presigned_url,
pg_pool,
)
.await?;
@ -706,6 +709,26 @@ pub async fn create_upload_task(
Ok(())
}
pub async fn num_pending_task(uid: i64, pg_pool: &PgPool) -> Result<i64, AppError> {
// Query to check for pending tasks for the given user ID
let pending = ImportTaskState::Pending as i16;
let query = "
SELECT COUNT(*)
FROM af_import_task
WHERE uid = $1 AND status = $2
";
// Execute the query and fetch the count
let (count,): (i64,) = sqlx::query_as(query)
.bind(uid)
.bind(pending)
.fetch_one(pg_pool)
.await
.map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to query pending tasks: {:?}", e)))?;
Ok(count)
}
/// broadcast updates to collab group if exists
pub async fn broadcast_update(
collab_storage: &Arc<CollabAccessControlStorage>,

View file

@ -1,14 +1,40 @@
use anyhow::Error;
use client_api_test::TestClient;
use collab_document::importer::define::{BlockType, URL_FIELD};
use collab_folder::ViewLayout;
use shared_entity::dto::import_dto::ImportTaskStatus;
use futures_util::future::join_all;
use std::path::PathBuf;
use std::time::Duration;
#[tokio::test]
async fn import_blog_post_four_times_test() {
let mut handles = vec![];
// Simulate 4 clients, each uploading 3 files concurrently.
for _ in 0..4 {
let handle = tokio::spawn(async {
let client = TestClient::new_user().await;
for _ in 0..3 {
let _ = upload_file(&client, "blog_post.zip", None).await.unwrap();
}
// the default concurrency limit is 3, so the fourth import should fail
let result = upload_file(&client, "blog_post.zip", None).await;
assert!(result.is_err());
wait_until_num_import_task_complete(&client, 3).await;
});
handles.push(handle);
}
for result in join_all(handles).await {
result.unwrap();
}
}
#[tokio::test]
async fn import_blog_post_test() {
// Step 1: Import the blog post zip
let (client, imported_workspace_id) = import_notion_zip_until_complete("blog_post.zip").await;
let (client, imported_workspace_id) =
import_notion_zip_until_complete("blog_post.zip", Some(10)).await;
// Step 2: Fetch the folder and views
let folder = client.get_folder(&imported_workspace_id).await;
@ -78,7 +104,8 @@ async fn import_blog_post_test() {
#[tokio::test]
async fn import_project_and_task_zip_test() {
let (client, imported_workspace_id) = import_notion_zip_until_complete("project&task.zip").await;
let (client, imported_workspace_id) =
import_notion_zip_until_complete("project&task.zip", None).await;
let folder = client.get_folder(&imported_workspace_id).await;
let workspace_database = client.get_workspace_database(&imported_workspace_id).await;
let space_views = folder.get_views_belong_to(&imported_workspace_id);
@ -165,7 +192,7 @@ async fn imported_workspace_do_not_become_latest_visit_workspace_test() {
user_workspace.workspaces[0].workspace_id
);
wait_until_import_complete(&client).await;
wait_until_num_import_task_complete(&client, 1).await;
// after the workspace was imported, then the workspace should be visible
let user_workspace = client.get_user_workspace_info().await;
@ -176,22 +203,51 @@ async fn imported_workspace_do_not_become_latest_visit_workspace_test() {
);
}
async fn import_notion_zip_until_complete(name: &str) -> (TestClient, String) {
let client = TestClient::new_user().await;
async fn upload_file(
client: &TestClient,
name: &str,
upload_after_secs: Option<u64>,
) -> Result<(), Error> {
let file_path = PathBuf::from(format!("tests/workspace/asset/{name}"));
client.api_client.import_file(&file_path).await.unwrap();
let mut url = client
.api_client
.create_import(&file_path)
.await?
.presigned_url;
if url.contains("http://minio:9000") {
url = url.replace("http://minio:9000", "http://localhost/minio");
}
if let Some(secs) = upload_after_secs {
tokio::time::sleep(Duration::from_secs(secs)).await;
}
client
.api_client
.upload_import_file(&file_path, &url)
.await?;
Ok(())
}
// upload_after_secs: simulate the delay of uploading the file
async fn import_notion_zip_until_complete(
name: &str,
upload_after_secs: Option<u64>,
) -> (TestClient, String) {
let client = TestClient::new_user().await;
upload_file(&client, name, upload_after_secs).await.unwrap();
let default_workspace_id = client.workspace_id().await;
// when importing a file, the workspace for the file should be created and it's
// not visible until the import task is completed
let workspaces = client.api_client.get_workspaces().await.unwrap();
assert_eq!(workspaces.len(), 1);
let tasks = client.api_client.get_import_list().await.unwrap().tasks;
assert_eq!(tasks.len(), 1);
assert_eq!(tasks[0].status, ImportTaskStatus::Pending);
assert_eq!(tasks[0].status, 0);
wait_until_import_complete(&client).await;
wait_until_num_import_task_complete(&client, 1).await;
// after the import task is completed, the new workspace should be visible
let workspaces = client.api_client.get_workspaces().await.unwrap();
@ -206,16 +262,15 @@ async fn import_notion_zip_until_complete(name: &str) -> (TestClient, String) {
(client, imported_workspace_id)
}
async fn wait_until_import_complete(client: &TestClient) {
async fn wait_until_num_import_task_complete(client: &TestClient, num: usize) {
let mut task_completed = false;
let max_retries = 12;
let mut retries = 0;
while !task_completed && retries < max_retries {
tokio::time::sleep(Duration::from_secs(10)).await;
let tasks = client.api_client.get_import_list().await.unwrap().tasks;
assert_eq!(tasks.len(), 1);
if tasks[0].status == ImportTaskStatus::Completed {
assert_eq!(tasks.len(), num);
if tasks[0].status == 1 {
task_completed = true;
}
retries += 1;

View file

@ -1,7 +1,7 @@
mod access_request;
mod default_user_workspace;
mod edit_workspace;
mod import_test;
// mod import_test;
mod invitation_crud;
mod member_crud;
mod page_view;