S3 collab cache (#1028)

* chore: collab cache for S3

* chore: adjust disk cache api to accommodate s3

* chore: move postgres dependent ops to disk cache

* chore: replace blob inserts from pg to s3

* chore: delete blob and collab exist now use s3

* chore: fix clippy errors

* chore: post rebase fixes

* chore: fix clippy warnings

* chore: fix imports

* chore: make snapshots work over S3

* chore: remove dead code

* chore: use compressed snapshots

* chore: add zstd compression

* chore: introduce collab size threshold to keep smaller collabs in postgres

* chore: remove collabs from S3 if they were put to postgres

* chore: update tests
This commit is contained in:
Bartosz Sypytkowski 2024-12-03 06:08:55 +01:00 committed by GitHub
parent 6718a7afb1
commit 9ff6f1c744
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1111 additions and 944 deletions

View file

@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT * FROM af_collab_snapshot\n WHERE sid = $1 AND deleted_at IS NULL;\n ",
"query": "\n SELECT * FROM af_collab_snapshot\n WHERE sid = $1 AND oid = $2 AND workspace_id = $3 AND deleted_at IS NULL;\n ",
"describe": {
"columns": [
{
@ -46,7 +46,9 @@
],
"parameters": {
"Left": [
"Int8"
"Int8",
"Text",
"Uuid"
]
},
"nullable": [
@ -60,5 +62,5 @@
false
]
},
"hash": "1f04da964eb7bd99b6cd5016f27d8ca0d3635933e4c681cbf3591e52a9b06663"
"hash": "21f66ca39be3377f8c5e4b218123e266fe8e03260ecd1891c644820892dda2b2"
}

5
Cargo.lock generated
View file

@ -697,6 +697,7 @@ dependencies = [
"validator",
"workspace-template",
"yrs",
"zstd 0.13.2",
]
[[package]]
@ -714,6 +715,7 @@ dependencies = [
"async-stream",
"async-trait",
"authentication",
"aws-sdk-s3",
"brotli 3.5.0",
"bytes",
"chrono",
@ -761,6 +763,7 @@ dependencies = [
"validator",
"workspace-template",
"yrs",
"zstd 0.13.2",
]
[[package]]
@ -802,6 +805,7 @@ dependencies = [
"tracing",
"tracing-subscriber",
"uuid",
"zstd 0.13.2",
]
[[package]]
@ -2921,6 +2925,7 @@ dependencies = [
"tracing",
"uuid",
"validator",
"zstd 0.13.2",
]
[[package]]

View file

@ -21,6 +21,7 @@ actix-router = "0.5.2"
actix-session = { version = "0.8", features = ["redis-rs-tls-session"] }
actix-multipart = { version = "0.7.2", features = ["derive"] }
openssl = { version = "0.10.62", features = ["vendored"] }
zstd.workspace = true
# serde
serde_json.workspace = true
@ -282,6 +283,7 @@ sanitize-filename = "0.5.0"
base64 = "0.22"
md5 = "0.7.0"
pin-project = "1.1.5"
zstd = { version = "0.13.2", features = [] }
# collaboration
yrs = { version = "0.21.3", features = ["sync"] }

View file

@ -19,6 +19,13 @@ use tracing::error;
use uuid::Uuid;
use validator::Validate;
/// The default compression level of ZSTD-compressed collabs.
pub const ZSTD_COMPRESSION_LEVEL: i32 = 3;
/// The threshold, in bytes of encoded collab data, used to determine whether collab data
/// should land in S3 or Postgres. Collabs larger than this value go to S3; collabs whose
/// size is at or below it land in Postgres.
pub const S3_COLLAB_THRESHOLD: usize = 2000;
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
pub struct CreateCollabParams {
#[validate(custom = "validate_not_empty_str")]
@ -67,6 +74,22 @@ impl CreateCollabParams {
pub struct CollabIndexParams {}
/// Describes a single collab write: the target workspace, the user id that
/// accompanies the write, and the collab parameters themselves.
pub struct PendingCollabWrite {
  /// Id of the workspace the collab is written to.
  pub workspace_id: String,
  /// User id passed along with the write.
  pub uid: i64,
  /// Parameters of the collab being written.
  pub params: CollabParams,
}

impl PendingCollabWrite {
  /// Builds a [`PendingCollabWrite`] from its three parts, taking ownership of each.
  pub fn new(workspace_id: String, uid: i64, params: CollabParams) -> Self {
    Self {
      workspace_id,
      uid,
      params,
    }
  }
}
#[derive(Debug, Clone, Validate, Serialize, Deserialize, PartialEq)]
pub struct CollabParams {
#[validate(custom = "validate_not_empty_str")]
@ -206,12 +229,12 @@ pub struct DeleteCollabParams {
pub workspace_id: String,
}
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
#[derive(Debug, Clone, Validate)]
pub struct InsertSnapshotParams {
#[validate(custom = "validate_not_empty_str")]
pub object_id: String,
#[validate(custom = "validate_not_empty_payload")]
pub encoded_collab_v1: Vec<u8>,
pub data: Bytes,
#[validate(custom = "validate_not_empty_str")]
pub workspace_id: String,
pub collab_type: CollabType,
@ -233,8 +256,6 @@ pub struct QuerySnapshotParams {
pub struct QueryCollabParams {
#[validate(custom = "validate_not_empty_str")]
pub workspace_id: String,
#[serde(flatten)]
#[validate]
pub inner: QueryCollab,
}

View file

@ -20,6 +20,7 @@ anyhow = "1.0.79"
serde.workspace = true
serde_json.workspace = true
tonic-proto.workspace = true
zstd.workspace = true
sqlx = { workspace = true, default-features = false, features = [
"postgres",

View file

@ -1,19 +1,17 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use collab::entity::EncodedCollab;
use collab_entity::CollabType;
use futures_util::{stream, StreamExt};
use itertools::{Either, Itertools};
use sqlx::{PgPool, Transaction};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tracing::{error, event, Level};
use crate::collab::disk_cache::CollabDiskCache;
use crate::collab::mem_cache::{cache_exp_secs_from_collab_type, CollabMemCache};
use crate::collab::CollabMetadata;
use crate::file::s3_client_impl::AwsS3BucketClientImpl;
use app_error::AppError;
use database_entity::dto::{CollabParams, QueryCollab, QueryCollabResult};
use database_entity::dto::{CollabParams, PendingCollabWrite, QueryCollab, QueryCollabResult};
#[derive(Clone)]
pub struct CollabCache {
@ -24,9 +22,13 @@ pub struct CollabCache {
}
impl CollabCache {
pub fn new(redis_conn_manager: redis::aio::ConnectionManager, pg_pool: PgPool) -> Self {
pub fn new(
redis_conn_manager: redis::aio::ConnectionManager,
pg_pool: PgPool,
s3: AwsS3BucketClientImpl,
) -> Self {
let mem_cache = CollabMemCache::new(redis_conn_manager.clone());
let disk_cache = CollabDiskCache::new(pg_pool.clone());
let disk_cache = CollabDiskCache::new(pg_pool.clone(), s3);
Self {
disk_cache,
mem_cache,
@ -35,37 +37,49 @@ impl CollabCache {
}
}
pub async fn get_collab_meta(
pub async fn bulk_insert_collab(
&self,
object_id: &str,
collab_type: &CollabType,
) -> Result<CollabMetadata, AppError> {
match self.mem_cache.get_collab_meta(object_id).await {
Ok(meta) => Ok(meta),
Err(_) => {
let row = self
.disk_cache
.get_collab_meta(object_id, collab_type)
.await?;
let meta = CollabMetadata {
object_id: row.oid,
workspace_id: row.workspace_id.to_string(),
};
workspace_id: &str,
uid: &i64,
params_list: Vec<CollabParams>,
) -> Result<(), AppError> {
self
.disk_cache
.bulk_insert_collab(workspace_id, uid, params_list.clone())
.await?;
// Spawn a background task to insert the collaboration metadata into the memory cache.
let cloned_meta = meta.clone();
let mem_cache = self.mem_cache.clone();
tokio::spawn(async move {
if let Err(err) = mem_cache.insert_collab_meta(cloned_meta).await {
error!("{:?}", err);
}
});
Ok(meta)
},
}
// update the mem cache without blocking the current task
let mem_cache = self.mem_cache.clone();
tokio::spawn(async move {
let timestamp = chrono::Utc::now().timestamp();
for params in params_list {
if let Err(err) = mem_cache
.insert_encode_collab_data(
&params.object_id,
&params.encoded_collab_v1,
timestamp,
Some(cache_exp_secs_from_collab_type(&params.collab_type)),
)
.await
.map_err(|err| AppError::Internal(err.into()))
{
tracing::warn!(
"Failed to insert collab `{}` into memory cache: {}",
params.object_id,
err
);
}
}
});
Ok(())
}
pub async fn get_encode_collab(&self, query: QueryCollab) -> Result<EncodedCollab, AppError> {
pub async fn get_encode_collab(
&self,
workspace_id: &str,
query: QueryCollab,
) -> Result<EncodedCollab, AppError> {
self.total_attempts.fetch_add(1, Ordering::Relaxed);
// Attempt to retrieve encoded collab from memory cache, falling back to disk cache if necessary.
if let Some(encoded_collab) = self.mem_cache.get_encode_collab(&query.object_id).await {
@ -81,7 +95,10 @@ impl CollabCache {
// Retrieve from disk cache as fallback. After retrieval, the value is inserted into the memory cache.
let object_id = query.object_id.clone();
let expiration_secs = cache_exp_secs_from_collab_type(&query.collab_type);
let encode_collab = self.disk_cache.get_collab_encoded_from_disk(query).await?;
let encode_collab = self
.disk_cache
.get_collab_encoded_from_disk(workspace_id, query)
.await?;
// spawn a task to insert the encoded collab into the memory cache
let cloned_encode_collab = encode_collab.clone();
@ -99,6 +116,7 @@ impl CollabCache {
/// returns a hashmap of the object_id to the encoded collab data.
pub async fn batch_get_encode_collab<T: Into<QueryCollab>>(
&self,
workspace_id: &str,
queries: Vec<T>,
) -> HashMap<String, QueryCollabResult> {
let queries = queries.into_iter().map(Into::into).collect::<Vec<_>>();
@ -129,7 +147,10 @@ impl CollabCache {
// 2. Retrieves remaining values from the disk cache for queries not satisfied by the memory cache.
// - These values are then merged into the final result set.
let values_from_disk_cache = self.disk_cache.batch_get_collab(disk_queries).await;
let values_from_disk_cache = self
.disk_cache
.batch_get_collab(workspace_id, disk_queries)
.await;
results.extend(values_from_disk_cache);
results
}
@ -140,15 +161,14 @@ impl CollabCache {
&self,
workspace_id: &str,
uid: &i64,
params: &CollabParams,
params: CollabParams,
transaction: &mut Transaction<'_, sqlx::Postgres>,
) -> Result<(), AppError> {
let collab_type = params.collab_type.clone();
let object_id = params.object_id.clone();
let encode_collab_data = params.encoded_collab_v1.clone();
self
.disk_cache
.upsert_collab_with_transaction(workspace_id, uid, params, transaction)
let s3 = self.disk_cache.s3_client();
CollabDiskCache::upsert_collab_with_transaction(workspace_id, uid, params, transaction, s3)
.await?;
// when the data is written to the disk cache but fails to be written to the memory cache
@ -176,9 +196,13 @@ impl CollabCache {
pub async fn get_encode_collab_from_disk(
&self,
workspace_id: &str,
query: QueryCollab,
) -> Result<EncodedCollab, AppError> {
let encode_collab = self.disk_cache.get_collab_encoded_from_disk(query).await?;
let encode_collab = self
.disk_cache
.get_collab_encoded_from_disk(workspace_id, query)
.await?;
Ok(encode_collab)
}
@ -187,30 +211,14 @@ impl CollabCache {
workspace_id: &str,
uid: &i64,
params: CollabParams,
transaction: &mut Transaction<'_, sqlx::Postgres>,
) -> Result<(), AppError> {
self
.disk_cache
.upsert_collab_with_transaction(workspace_id, uid, &params, transaction)
.upsert_collab(workspace_id, uid, params)
.await?;
Ok(())
}
pub async fn insert_encode_collab_to_mem(&self, params: &CollabParams) -> Result<(), AppError> {
let timestamp = chrono::Utc::now().timestamp();
self
.mem_cache
.insert_encode_collab_data(
&params.object_id,
&params.encoded_collab_v1,
timestamp,
Some(cache_exp_secs_from_collab_type(&params.collab_type)),
)
.await
.map_err(|err| AppError::Internal(err.into()))?;
Ok(())
}
pub fn query_state(&self) -> QueryState {
let success_attempts = self.success_attempts.load(Ordering::Relaxed);
let total_attempts = self.total_attempts.load(Ordering::Relaxed);
@ -220,25 +228,31 @@ impl CollabCache {
}
}
pub async fn delete_collab(&self, object_id: &str) -> Result<(), AppError> {
pub async fn delete_collab(&self, workspace_id: &str, object_id: &str) -> Result<(), AppError> {
self.mem_cache.remove_encode_collab(object_id).await?;
self.disk_cache.delete_collab(object_id).await?;
self
.disk_cache
.delete_collab(workspace_id, object_id)
.await?;
Ok(())
}
pub async fn is_exist(&self, oid: &str) -> Result<bool, AppError> {
pub async fn is_exist(&self, workspace_id: &str, oid: &str) -> Result<bool, AppError> {
if let Ok(value) = self.mem_cache.is_exist(oid).await {
if value {
return Ok(value);
}
}
let is_exist = self.disk_cache.is_exist(oid).await?;
let is_exist = self.disk_cache.is_exist(workspace_id, oid).await?;
Ok(is_exist)
}
pub fn pg_pool(&self) -> &sqlx::PgPool {
&self.disk_cache.pg_pool
/// Writes a batch of pending collabs by delegating to the disk cache.
///
/// Returns whatever count the disk cache reports; any error from the disk
/// cache is passed through unchanged.
pub async fn batch_insert_collab(
  &self,
  records: Vec<PendingCollabWrite>,
) -> Result<u64, AppError> {
  let written = self.disk_cache.batch_insert_collab(records).await?;
  Ok(written)
}
}

View file

@ -287,8 +287,8 @@ where
pub async fn batch_select_collab_blob(
pg_pool: &PgPool,
queries: Vec<QueryCollab>,
) -> HashMap<String, QueryCollabResult> {
let mut results = HashMap::new();
results: &mut HashMap<String, QueryCollabResult>,
) {
let mut object_ids_by_collab_type: HashMap<CollabType, Vec<String>> = HashMap::new();
for params in queries {
object_ids_by_collab_type
@ -337,8 +337,6 @@ pub async fn batch_select_collab_blob(
Err(err) => error!("Batch get collab errors: {}", err),
}
}
results
}
#[derive(Debug, sqlx::FromRow)]
@ -461,15 +459,20 @@ pub async fn create_snapshot_and_maintain_limit<'a>(
#[inline]
pub async fn select_snapshot(
pg_pool: &PgPool,
workspace_id: &str,
object_id: &str,
snapshot_id: &i64,
) -> Result<Option<AFSnapshotRow>, Error> {
let workspace_id = Uuid::from_str(workspace_id).map_err(|err| Error::Decode(err.into()))?;
let row = sqlx::query_as!(
AFSnapshotRow,
r#"
SELECT * FROM af_collab_snapshot
WHERE sid = $1 AND deleted_at IS NULL;
WHERE sid = $1 AND oid = $2 AND workspace_id = $3 AND deleted_at IS NULL;
"#,
snapshot_id,
object_id,
workspace_id
)
.fetch_optional(pg_pool)
.await?;

View file

@ -7,12 +7,10 @@ use database_entity::dto::{
};
use collab::entity::EncodedCollab;
use collab_entity::CollabType;
use collab_rt_entity::ClientCollabMessage;
use serde::{Deserialize, Serialize};
use sqlx::Transaction;
use std::collections::HashMap;
use std::sync::Arc;
pub const COLLAB_SNAPSHOT_LIMIT: i64 = 30;
pub const SNAPSHOT_PER_HOUR: i64 = 6;
@ -134,6 +132,7 @@ pub trait CollabStorage: Send + Sync + 'static {
async fn batch_get_collab(
&self,
uid: &i64,
workspace_id: &str,
queries: Vec<QueryCollab>,
from_editing_collab: bool,
) -> HashMap<String, QueryCollabResult>;
@ -149,12 +148,7 @@ pub trait CollabStorage: Send + Sync + 'static {
/// * `Result<()>` - Returns `Ok(())` if the collaboration was deleted successfully, `Err` otherwise.
async fn delete_collab(&self, workspace_id: &str, uid: &i64, object_id: &str) -> AppResult<()>;
async fn query_collab_meta(
&self,
object_id: &str,
collab_type: &CollabType,
) -> AppResult<CollabMetadata>;
async fn should_create_snapshot(&self, oid: &str) -> Result<bool, AppError>;
async fn should_create_snapshot(&self, workspace_id: &str, oid: &str) -> Result<bool, AppError>;
async fn create_snapshot(&self, params: InsertSnapshotParams) -> AppResult<AFSnapshotMeta>;
async fn queue_snapshot(&self, params: InsertSnapshotParams) -> AppResult<()>;
@ -167,143 +161,11 @@ pub trait CollabStorage: Send + Sync + 'static {
) -> AppResult<SnapshotData>;
/// Returns list of snapshots for given object_id in descending order of creation time.
async fn get_collab_snapshot_list(&self, oid: &str) -> AppResult<AFSnapshotMetas>;
}
#[async_trait]
impl<T> CollabStorage for Arc<T>
where
T: CollabStorage,
{
fn encode_collab_redis_query_state(&self) -> (u64, u64) {
self.as_ref().encode_collab_redis_query_state()
}
async fn queue_insert_or_update_collab(
async fn get_collab_snapshot_list(
&self,
workspace_id: &str,
uid: &i64,
params: CollabParams,
write_immediately: bool,
) -> AppResult<()> {
self
.as_ref()
.queue_insert_or_update_collab(workspace_id, uid, params, write_immediately)
.await
}
async fn batch_insert_new_collab(
&self,
workspace_id: &str,
uid: &i64,
params: Vec<CollabParams>,
) -> AppResult<()> {
self
.as_ref()
.batch_insert_new_collab(workspace_id, uid, params)
.await
}
async fn insert_new_collab_with_transaction(
&self,
workspace_id: &str,
uid: &i64,
params: CollabParams,
transaction: &mut Transaction<'_, sqlx::Postgres>,
action_description: &str,
) -> AppResult<()> {
self
.as_ref()
.insert_new_collab_with_transaction(
workspace_id,
uid,
params,
transaction,
action_description,
)
.await
}
async fn get_encode_collab(
&self,
origin: GetCollabOrigin,
params: QueryCollabParams,
from_editing_collab: bool,
) -> AppResult<EncodedCollab> {
self
.as_ref()
.get_encode_collab(origin, params, from_editing_collab)
.await
}
async fn broadcast_encode_collab(
&self,
object_id: String,
collab_messages: Vec<ClientCollabMessage>,
) -> Result<(), AppError> {
self
.as_ref()
.broadcast_encode_collab(object_id, collab_messages)
.await
}
async fn batch_get_collab(
&self,
uid: &i64,
queries: Vec<QueryCollab>,
from_editing_collab: bool,
) -> HashMap<String, QueryCollabResult> {
self
.as_ref()
.batch_get_collab(uid, queries, from_editing_collab)
.await
}
async fn delete_collab(&self, workspace_id: &str, uid: &i64, object_id: &str) -> AppResult<()> {
self
.as_ref()
.delete_collab(workspace_id, uid, object_id)
.await
}
async fn query_collab_meta(
&self,
object_id: &str,
collab_type: &CollabType,
) -> AppResult<CollabMetadata> {
self
.as_ref()
.query_collab_meta(object_id, collab_type)
.await
}
async fn should_create_snapshot(&self, oid: &str) -> Result<bool, AppError> {
self.as_ref().should_create_snapshot(oid).await
}
async fn create_snapshot(&self, params: InsertSnapshotParams) -> AppResult<AFSnapshotMeta> {
self.as_ref().create_snapshot(params).await
}
async fn queue_snapshot(&self, params: InsertSnapshotParams) -> AppResult<()> {
self.as_ref().queue_snapshot(params).await
}
async fn get_collab_snapshot(
&self,
workspace_id: &str,
object_id: &str,
snapshot_id: &i64,
) -> AppResult<SnapshotData> {
self
.as_ref()
.get_collab_snapshot(workspace_id, object_id, snapshot_id)
.await
}
async fn get_collab_snapshot_list(&self, oid: &str) -> AppResult<AFSnapshotMetas> {
self.as_ref().get_collab_snapshot_list(oid).await
}
oid: &str,
) -> AppResult<AFSnapshotMetas>;
}
#[derive(Debug, Clone, Deserialize, Serialize)]

View file

@ -1,61 +1,114 @@
use std::collections::HashMap;
use std::time::Duration;
use collab::entity::EncodedCollab;
use anyhow::{anyhow, Context};
use bytes::Bytes;
use collab::entity::{EncodedCollab, EncoderVersion};
use collab_entity::CollabType;
use sqlx::{Error, PgPool, Transaction};
use std::collections::HashMap;
use std::ops::DerefMut;
use std::time::{Duration, Instant};
use tokio::task::JoinSet;
use tokio::time::sleep;
use tracing::{event, instrument, Level};
use tracing::{error, instrument};
use uuid::Uuid;
use crate::collab::util::encode_collab_from_bytes;
use crate::collab::{
batch_select_collab_blob, insert_into_af_collab, is_collab_exists, select_blob_from_af_collab,
select_collab_meta_from_af_collab, AppResult,
batch_select_collab_blob, insert_into_af_collab, insert_into_af_collab_bulk_for_user,
is_collab_exists, select_blob_from_af_collab, AppResult,
};
use crate::file::s3_client_impl::AwsS3BucketClientImpl;
use crate::file::{BucketClient, ResponseBlob};
use crate::index::upsert_collab_embeddings;
use crate::pg_row::AFCollabRowMeta;
use app_error::AppError;
use database_entity::dto::{CollabParams, QueryCollab, QueryCollabResult};
use database_entity::dto::{
CollabParams, PendingCollabWrite, QueryCollab, QueryCollabResult, S3_COLLAB_THRESHOLD,
ZSTD_COMPRESSION_LEVEL,
};
#[derive(Clone)]
pub struct CollabDiskCache {
pub pg_pool: PgPool,
pg_pool: PgPool,
s3: AwsS3BucketClientImpl,
}
impl CollabDiskCache {
pub fn new(pg_pool: PgPool) -> Self {
Self { pg_pool }
pub fn new(pg_pool: PgPool, s3: AwsS3BucketClientImpl) -> Self {
Self { pg_pool, s3 }
}
pub async fn is_exist(&self, object_id: &str) -> AppResult<bool> {
let is_exist = is_collab_exists(object_id, &self.pg_pool).await?;
Ok(is_exist)
}
pub async fn get_collab_meta(
&self,
object_id: &str,
collab_type: &CollabType,
) -> AppResult<AFCollabRowMeta> {
let result = select_collab_meta_from_af_collab(&self.pg_pool, object_id, collab_type).await?;
match result {
None => {
let msg = format!("Can't find the row for object_id: {}", object_id);
Err(AppError::RecordNotFound(msg))
},
Some(meta) => Ok(meta),
pub async fn is_exist(&self, workspace_id: &str, object_id: &str) -> AppResult<bool> {
let dir = collab_key_prefix(workspace_id, object_id);
let resp = self.s3.list_dir(&dir, 1).await?;
if resp.is_empty() {
// fallback to Postgres
Ok(is_collab_exists(object_id, &self.pg_pool).await?)
} else {
Ok(true)
}
}
pub async fn upsert_collab_with_transaction(
pub async fn upsert_collab(
&self,
workspace_id: &str,
uid: &i64,
params: &CollabParams,
transaction: &mut Transaction<'_, sqlx::Postgres>,
params: CollabParams,
) -> AppResult<()> {
insert_into_af_collab(transaction, uid, workspace_id, params).await?;
// Start a database transaction
let mut transaction = self
.pg_pool
.begin()
.await
.context("Failed to acquire transaction for writing pending collaboration data")
.map_err(AppError::from)?;
Self::upsert_collab_with_transaction(
workspace_id,
uid,
params,
&mut transaction,
self.s3.clone(),
)
.await?;
tokio::time::timeout(Duration::from_secs(10), transaction.commit())
.await
.map_err(|_| {
AppError::Internal(anyhow!(
"Timeout when committing the transaction for pending collaboration data"
))
})??;
Ok(())
}
/// Returns a clone of the S3 bucket client held by this disk cache.
pub fn s3_client(&self) -> AwsS3BucketClientImpl {
  let client = &self.s3;
  client.clone()
}
pub async fn upsert_collab_with_transaction(
workspace_id: &str,
uid: &i64,
mut params: CollabParams,
transaction: &mut Transaction<'_, sqlx::Postgres>,
s3: AwsS3BucketClientImpl,
) -> AppResult<()> {
let mut delete_from_s3 = Vec::new();
let key = collab_key(workspace_id, &params.object_id);
if params.encoded_collab_v1.len() > S3_COLLAB_THRESHOLD {
// put collab into S3
let encoded_collab = std::mem::take(&mut params.encoded_collab_v1);
tokio::spawn(Self::insert_blob_with_retries(
s3.clone(),
key,
encoded_collab,
3,
));
} else {
// put collab into Postgres (and remove outdated version from S3)
delete_from_s3.push(key);
}
insert_into_af_collab(transaction, uid, workspace_id, &params).await?;
if let Some(em) = &params.embeddings {
tracing::info!(
"saving collab {} embeddings (cost: {} tokens)",
@ -70,23 +123,56 @@ impl CollabDiskCache {
em.params.clone(),
)
.await?;
if !delete_from_s3.is_empty() {
tokio::spawn(async move {
if let Err(err) = s3.delete_blobs(delete_from_s3).await {
tracing::warn!("failed to delete outdated collab from S3: {}", err);
}
});
}
} else if params.collab_type == CollabType::Document {
tracing::info!("no embeddings to save for collab {}", params.object_id);
}
Ok(())
}
#[instrument(level = "trace", skip_all)]
pub async fn get_collab_encoded_from_disk(
&self,
workspace_id: &str,
query: QueryCollab,
) -> Result<EncodedCollab, AppError> {
event!(
Level::DEBUG,
"try get {}:{} from disk",
query.collab_type,
query.object_id
);
tracing::debug!("try get {}:{} from s3", query.collab_type, query.object_id);
let key = collab_key(workspace_id, &query.object_id);
match self.s3.get_blob(&key).await {
Ok(resp) => {
let blob = resp.to_blob();
let now = Instant::now();
let decompressed = zstd::decode_all(&*blob)?;
tracing::trace!(
"decompressed collab {}B -> {}B in {:?}",
blob.len(),
decompressed.len(),
now.elapsed()
);
return Ok(EncodedCollab {
state_vector: Default::default(),
doc_state: decompressed.into(),
version: EncoderVersion::V1,
});
},
Err(AppError::RecordNotFound(_)) => {
tracing::debug!(
"try get {}:{} from database",
query.collab_type,
query.object_id
);
},
Err(err) => {
return Err(err);
},
}
const MAX_ATTEMPTS: usize = 3;
let mut attempts = 0;
@ -121,14 +207,117 @@ impl CollabDiskCache {
}
}
pub async fn batch_get_collab(
//FIXME: this and `batch_insert_collab` duplicate similar logic.
pub async fn bulk_insert_collab(
&self,
queries: Vec<QueryCollab>,
) -> HashMap<String, QueryCollabResult> {
batch_select_collab_blob(&self.pg_pool, queries).await
workspace_id: &str,
uid: &i64,
mut params_list: Vec<CollabParams>,
) -> Result<(), AppError> {
if params_list.is_empty() {
return Ok(());
}
let mut delete_from_s3 = Vec::new();
let mut blobs = HashMap::new();
for param in params_list.iter_mut() {
let key = collab_key(workspace_id, &param.object_id);
if param.encoded_collab_v1.len() > S3_COLLAB_THRESHOLD {
let blob = std::mem::take(&mut param.encoded_collab_v1);
blobs.insert(key, blob);
} else {
// put collab into Postgres (and remove outdated version from S3)
delete_from_s3.push(key);
}
}
let mut transaction = self.pg_pool.begin().await?;
insert_into_af_collab_bulk_for_user(&mut transaction, uid, workspace_id, &params_list).await?;
transaction.commit().await?;
batch_put_collab_to_s3(&self.s3, blobs).await?;
if !delete_from_s3.is_empty() {
self.s3.delete_blobs(delete_from_s3).await?;
}
Ok(())
}
pub async fn delete_collab(&self, object_id: &str) -> AppResult<()> {
pub async fn batch_insert_collab(
&self,
records: Vec<PendingCollabWrite>,
) -> Result<u64, AppError> {
if records.is_empty() {
return Ok(0);
}
let s3 = self.s3.clone();
// Start a database transaction
let mut transaction = self
.pg_pool
.begin()
.await
.context("Failed to acquire transaction for writing pending collaboration data")
.map_err(AppError::from)?;
let mut successful_writes = 0;
// Insert each record into the database within the transaction context
let mut action_description = String::new();
for (index, record) in records.into_iter().enumerate() {
let params = record.params;
action_description = format!("{}", params);
let savepoint_name = format!("sp_{}", index);
// using savepoint to rollback the transaction if the insert fails
sqlx::query(&format!("SAVEPOINT {}", savepoint_name))
.execute(transaction.deref_mut())
.await?;
if let Err(_err) = Self::upsert_collab_with_transaction(
&record.workspace_id,
&record.uid,
params,
&mut transaction,
s3.clone(),
)
.await
{
sqlx::query(&format!("ROLLBACK TO SAVEPOINT {}", savepoint_name))
.execute(transaction.deref_mut())
.await?;
} else {
successful_writes += 1;
}
}
// Commit the transaction to finalize all writes
match tokio::time::timeout(Duration::from_secs(10), transaction.commit()).await {
Ok(result) => {
result.map_err(AppError::from)?;
},
Err(_) => {
error!(
"Timeout waiting for committing the transaction for pending write:{}",
action_description
);
return Err(AppError::Internal(anyhow!(
"Timeout when committing the transaction for pending collaboration data"
)));
},
}
Ok(successful_writes)
}
/// Resolves a batch of collab queries, consulting S3 first and Postgres second.
///
/// Queries that S3 reports as not found are forwarded to
/// `batch_select_collab_blob`; both lookups write into the same map, which is
/// returned keyed by object id.
pub async fn batch_get_collab(
  &self,
  workspace_id: &str,
  queries: Vec<QueryCollab>,
) -> HashMap<String, QueryCollabResult> {
  let mut found = HashMap::new();
  let missing_in_s3 = batch_get_collab_from_s3(&self.s3, workspace_id, queries, &mut found).await;
  batch_select_collab_blob(&self.pg_pool, missing_in_s3, &mut found).await;
  found
}
pub async fn delete_collab(&self, workspace_id: &str, object_id: &str) -> AppResult<()> {
sqlx::query!(
r#"
UPDATE af_collab
@ -140,6 +329,171 @@ impl CollabDiskCache {
)
.execute(&self.pg_pool)
.await?;
let key = collab_key(workspace_id, object_id);
match self.s3.delete_blob(&key).await {
Ok(_) | Err(AppError::RecordNotFound(_)) => Ok(()),
Err(err) => Err(err),
}
}
/// Compresses `blob` and uploads it to S3 under `key`, retrying while S3
/// reports itself as temporarily unavailable.
///
/// `retries` is the number of additional attempts allowed after the first
/// one; each retry is preceded by a 5 second pause. Only
/// `AppError::ServiceTemporaryUnavailable` is retried.
///
/// # Errors
///
/// Returns an error when the blob cannot be compressed, when the upload fails
/// with a non-retryable error, or when the retry budget is exhausted. The
/// upload error is logged before it is returned, so callers that detach this
/// future still leave a trace of the failure.
async fn insert_blob_with_retries(
  s3: AwsS3BucketClientImpl,
  key: String,
  blob: Bytes,
  mut retries: usize,
) -> Result<(), AppError> {
  let doc_state = Self::compress_encoded_collab(blob)?;
  loop {
    match s3.put_blob(&key, doc_state.clone().into(), None).await {
      Ok(_) => return Ok(()),
      Err(AppError::ServiceTemporaryUnavailable(reason)) if retries > 0 => {
        tracing::info!(
          "S3 service is temporarily unavailable: {}. Remaining retries: {}",
          reason,
          retries
        );
        retries -= 1;
        sleep(Duration::from_secs(5)).await;
      },
      Err(err) => {
        tracing::error!("Failed to save collab to S3: {}", err);
        // Report the failure instead of claiming success.
        return Err(err);
      },
    }
  }
}
fn compress_encoded_collab(encoded_collab_v1: Bytes) -> Result<Bytes, AppError> {
let encoded_collab = EncodedCollab::decode_from_bytes(&encoded_collab_v1)
.map_err(|err| AppError::Internal(err.into()))?;
let now = Instant::now();
let doc_state = zstd::encode_all(&*encoded_collab.doc_state, ZSTD_COMPRESSION_LEVEL)?;
tracing::trace!(
"compressed collab {}B -> {}B in {:?}",
encoded_collab_v1.len(),
doc_state.len(),
now.elapsed()
);
Ok(doc_state.into())
}
}
async fn batch_put_collab_to_s3(
s3: &AwsS3BucketClientImpl,
collabs: HashMap<String, Bytes>,
) -> Result<(), AppError> {
let mut join_set = JoinSet::<Result<(), AppError>>::new();
let mut i = 0;
for (key, blob) in collabs {
let s3 = s3.clone();
join_set.spawn(async move {
let compressed = CollabDiskCache::compress_encoded_collab(blob)?;
s3.put_blob(&key, compressed.into(), None).await?;
Ok(())
});
i += 1;
if i % 500 == 0 {
while let Some(result) = join_set.join_next().await {
result.map_err(|err| AppError::Internal(err.into()))??;
}
}
}
while let Some(result) = join_set.join_next().await {
result.map_err(|err| AppError::Internal(err.into()))??;
}
Ok(())
}
/// Fetches the requested collabs from S3 using concurrent tasks.
///
/// One task is spawned per query to download the blob stored under
/// `collab_key(workspace_id, object_id)`. Finished tasks are gathered after
/// every 500 spawns and once more after the last query.
///
/// For each query:
/// - a downloaded blob is ZSTD-decompressed, wrapped in an `EncodedCollab`
///   (empty state vector, `EncoderVersion::V1`), re-encoded and stored in
///   `results` as `QueryCollabResult::Success`;
/// - a decompression, re-encoding or S3 error is stored in `results` as
///   `QueryCollabResult::Failed`;
/// - `AppError::RecordNotFound` puts the query into the returned vector so the
///   caller can look it up elsewhere.
///
/// A task that fails to join is only logged; its query ends up in neither
/// `results` nor the returned vector.
async fn batch_get_collab_from_s3(
  s3: &AwsS3BucketClientImpl,
  workspace_id: &str,
  params: Vec<QueryCollab>,
  results: &mut HashMap<String, QueryCollabResult>,
) -> Vec<QueryCollab> {
  /// Outcome of a single S3 download task.
  enum GetResult {
    /// Object id and the still-compressed blob.
    Found(String, Vec<u8>),
    /// The original query, to be retried against another store.
    NotFound(QueryCollab),
    /// Object id and the error message.
    Error(String, String),
  }

  /// Drains `join_set`, sorting every finished task into `results` or `not_found`.
  async fn gather(
    join_set: &mut JoinSet<GetResult>,
    results: &mut HashMap<String, QueryCollabResult>,
    not_found: &mut Vec<QueryCollab>,
  ) {
    while let Some(result) = join_set.join_next().await {
      let now = Instant::now();
      match result {
        Ok(GetResult::Found(object_id, compressed)) => {
          let outcome = match zstd::decode_all(&*compressed) {
            Ok(decompressed) => {
              tracing::trace!(
                "decompressed collab {}B -> {}B in {:?}",
                compressed.len(),
                decompressed.len(),
                now.elapsed()
              );
              let encoded_collab = EncodedCollab {
                state_vector: Default::default(),
                doc_state: decompressed.into(),
                version: EncoderVersion::V1,
              };
              // Report an encoding failure for this object instead of panicking.
              match encoded_collab.encode_to_bytes() {
                Ok(encode_collab_v1) => QueryCollabResult::Success { encode_collab_v1 },
                Err(err) => QueryCollabResult::Failed {
                  error: err.to_string(),
                },
              }
            },
            Err(err) => QueryCollabResult::Failed {
              error: err.to_string(),
            },
          };
          results.insert(object_id, outcome);
        },
        Ok(GetResult::NotFound(query)) => not_found.push(query),
        Ok(GetResult::Error(object_id, error)) => {
          results.insert(object_id, QueryCollabResult::Failed { error });
        },
        Err(err) => error!("Failed to get collab from S3: {}", err),
      }
    }
  }

  let mut not_found = Vec::new();
  let mut i = 0;
  let mut join_set = JoinSet::new();
  for query in params {
    let key = collab_key(workspace_id, &query.object_id);
    let s3 = s3.clone();
    join_set.spawn(async move {
      match s3.get_blob(&key).await {
        Ok(resp) => GetResult::Found(query.object_id, resp.to_blob()),
        Err(AppError::RecordNotFound(_)) => GetResult::NotFound(query),
        Err(err) => GetResult::Error(query.object_id, err.to_string()),
      }
    });
    i += 1;
    // Bound the number of in-flight downloads by draining after every 500 spawns.
    if i % 500 == 0 {
      gather(&mut join_set, results, &mut not_found).await;
    }
  }
  // gather remaining results from the last chunk
  gather(&mut join_set, results, &mut not_found).await;
  not_found
}
/// Builds the S3 key prefix (`collabs/<workspace_id>/<object_id>/`) under
/// which the objects of one collab are stored.
fn collab_key_prefix(workspace_id: &str, object_id: &str) -> String {
  let mut prefix = String::with_capacity("collabs//".len() + workspace_id.len() + object_id.len() + 1);
  prefix.push_str("collabs/");
  prefix.push_str(workspace_id);
  prefix.push('/');
  prefix.push_str(object_id);
  prefix.push('/');
  prefix
}
/// Builds the S3 object key of the ZSTD-compressed encoded collab:
/// `collabs/<workspace_id>/<object_id>/encoded_collab.v1.zstd`.
fn collab_key(workspace_id: &str, object_id: &str) -> String {
  const FILE_NAME: &str = "encoded_collab.v1.zstd";
  ["collabs", workspace_id, object_id, FILE_NAME].join("/")
}

View file

@ -60,6 +60,8 @@ pub trait BucketClient {
) -> Result<(usize, String), AppError>;
async fn remove_dir(&self, dir: &str) -> Result<(), AppError>;
async fn list_dir(&self, dir: &str, limit: usize) -> Result<Vec<String>, AppError>;
}
pub trait BlobKey: Send + Sync {
@ -90,7 +92,7 @@ where
#[instrument(skip_all, err)]
#[inline]
pub async fn put_blob<K: BlobKey>(
pub async fn put_blob_with_content_type<K: BlobKey>(
&self,
key: K,
file_stream: ByteStream,

View file

@ -102,7 +102,7 @@ impl AwsS3BucketClientImpl {
.await
.map_err(|e| AppError::Internal(anyhow::anyhow!(e)))?;
let content_length = head_object_result
let content_len = head_object_result
.content_length()
.ok_or_else(|| AppError::Unhandled("Content-Length not found".to_string()))?;
let content_type = head_object_result
@ -110,7 +110,13 @@ impl AwsS3BucketClientImpl {
.map(|s| s.to_string())
.unwrap_or_else(|| "application/octet-stream".to_string());
Ok((content_length as usize, content_type))
trace!(
"completed upload to S3: {} ({} bytes)",
object_key,
content_len
);
Ok((content_len as usize, content_type))
}
}
@ -140,6 +146,8 @@ impl BucketClient for AwsS3BucketClientImpl {
_ => AppError::Internal(anyhow!("Failed to upload object to S3: {}", err)),
})?;
trace!("put object to S3: {}", object_key);
Ok(())
}
@ -149,11 +157,6 @@ impl BucketClient for AwsS3BucketClientImpl {
stream: ByteStream,
content_type: &str,
) -> Result<(), AppError> {
trace!(
"Uploading object to S3 bucket:{}, key {}",
self.bucket,
object_key,
);
self
.client
.put_object()
@ -170,6 +173,8 @@ impl BucketClient for AwsS3BucketClientImpl {
_ => AppError::Internal(anyhow!("Failed to upload object to S3: {}", err)),
})?;
trace!("put object to S3: {} ({})", object_key, content_type);
Ok(())
}
@ -183,13 +188,15 @@ impl BucketClient for AwsS3BucketClientImpl {
.await
.map_err(|err| anyhow!("Failed to delete object to S3: {}", err))?;
trace!("deleted object from S3: {}", object_key);
Ok(S3ResponseData::from(output))
}
async fn delete_blobs(&self, object_keys: Vec<String>) -> Result<Self::ResponseData, AppError> {
let mut delete_object_ids: Vec<aws_sdk_s3::types::ObjectIdentifier> = vec![];
let mut delete_object_ids: Vec<ObjectIdentifier> = vec![];
for obj in object_keys {
let obj_id = aws_sdk_s3::types::ObjectIdentifier::builder()
let obj_id = ObjectIdentifier::builder()
.key(obj)
.build()
.map_err(|err| {
@ -198,6 +205,7 @@ impl BucketClient for AwsS3BucketClientImpl {
delete_object_ids.push(obj_id);
}
let len = delete_object_ids.len();
let output = self
.client
.delete_objects()
@ -214,6 +222,8 @@ impl BucketClient for AwsS3BucketClientImpl {
.await
.map_err(|err| anyhow!("Failed to delete objects from S3: {}", err))?;
trace!("deleted {} objects from S3", len);
Ok(S3ResponseData::from(output))
}
@ -229,6 +239,9 @@ impl BucketClient for AwsS3BucketClientImpl {
Ok(output) => match output.body.collect().await {
Ok(body) => {
let data = body.into_bytes().to_vec();
trace!("get object from S3: {} ({} bytes)", object_key, data.len());
Ok(S3ResponseData::new_with_data(data, output.content_type))
},
Err(err) => Err(AppError::from(anyhow!("Failed to collect body: {}", err))),
@ -256,12 +269,8 @@ impl BucketClient for AwsS3BucketClientImpl {
object_key: &str,
req: CreateUploadRequest,
) -> Result<CreateUploadResponse, AppError> {
trace!(
"Creating upload to S3 bucket:{}, key {}, request: {}",
self.bucket,
object_key,
req
);
trace!("creating multi-part upload to S3: {} - {}", object_key, req);
let multipart_upload_res = self
.client
.create_multipart_upload()
@ -289,12 +298,7 @@ impl BucketClient for AwsS3BucketClientImpl {
if req.body.is_empty() {
return Err(AppError::InvalidRequest("body is empty".to_string()));
}
trace!(
"Uploading part to S3 bucket:{}, key {}, request: {}",
self.bucket,
object_key,
req,
);
trace!("multi-part upload to s3: {} - {}", object_key, req,);
let body = ByteStream::from(req.body);
let upload_part_res = self
.client
@ -323,12 +327,6 @@ impl BucketClient for AwsS3BucketClientImpl {
object_key: &str,
req: CompleteUploadRequest,
) -> Result<(usize, String), AppError> {
trace!(
"Completing upload to S3 bucket:{}, key {}, request: {}",
self.bucket,
object_key,
req,
);
let parts = req
.parts
.into_iter()
@ -380,7 +378,7 @@ impl BucketClient for AwsS3BucketClientImpl {
.collect();
trace!(
"objects_to_delete: {:?} at directory: {}",
"deleting {} objects at directory: {}",
objects_to_delete.len(),
parent_dir
);
@ -393,15 +391,6 @@ impl BucketClient for AwsS3BucketClientImpl {
Vec::new()
};
trace!(
"Deleting {} objects: {:?}",
parent_dir,
objects_to_delete
.iter()
.map(|object| &object.key)
.collect::<Vec<&String>>()
);
let delete = Delete::builder()
.set_objects(Some(objects_to_delete))
.build()
@ -444,6 +433,27 @@ impl BucketClient for AwsS3BucketClientImpl {
Ok(())
}
/// Lists the keys of the objects stored under the `dir` prefix.
///
/// A single `ListObjectsV2` request is issued with `max_keys` set from `limit`; no
/// continuation token is followed, so at most one page of keys is returned. Entries the
/// service returns without a key are skipped.
///
/// # Errors
/// Returns an error when the `ListObjectsV2` request fails.
async fn list_dir(&self, dir: &str, limit: usize) -> Result<Vec<String>, AppError> {
  // `max_keys` takes an `i32`. Clamp instead of using a bare `as` cast, which would turn a
  // `limit` above `i32::MAX` into a negative or unrelated value.
  let max_keys = limit.min(i32::MAX as usize) as i32;
  let list_objects = self
    .client
    .list_objects_v2()
    .bucket(&self.bucket)
    .prefix(dir)
    .max_keys(max_keys)
    .send()
    .await
    .map_err(|err| anyhow!("Failed to list object: {}", err))?;
  Ok(
    list_objects
      .contents
      .unwrap_or_default()
      .into_iter()
      .filter_map(|o| o.key)
      .collect(),
  )
}
}
#[derive(Debug)]

View file

@ -89,7 +89,11 @@ validator = "0.16.1"
rayon.workspace = true
tiktoken-rs = "0.6.0"
unicode-segmentation = "1.9.0"
aws-sdk-s3 = { version = "1.36.0", features = [
"behavior-version-latest",
"rt-tokio",
] }
zstd.workspace = true
[dev-dependencies]
rand = "0.8.5"

View file

@ -10,6 +10,11 @@ use actix_web::dev::Server;
use actix_web::web::Data;
use actix_web::{App, HttpServer};
use anyhow::{Context, Error};
use aws_sdk_s3::config::{Credentials, Region, SharedCredentialsProvider};
use aws_sdk_s3::operation::create_bucket::CreateBucketError;
use aws_sdk_s3::types::{
BucketInfo, BucketLocationConstraint, BucketType, CreateBucketConfiguration,
};
use database::collab::cache::CollabCache;
use secrecy::ExposeSecret;
use sqlx::postgres::PgPoolOptions;
@ -17,15 +22,15 @@ use sqlx::PgPool;
use tracing::info;
use crate::actix_ws::server::RealtimeServerActor;
use crate::api::{collab_scope, ws_scope};
use crate::collab::access_control::CollabStorageAccessControlImpl;
use access_control::casbin::access::AccessControl;
use appflowy_ai_client::client::AppFlowyAIClient;
use crate::api::{collab_scope, ws_scope};
use database::file::s3_client_impl::AwsS3BucketClientImpl;
use crate::collab::storage::CollabStorageImpl;
use crate::command::{CLCommandReceiver, CLCommandSender};
use crate::config::{Config, DatabaseSetting};
use crate::config::{Config, DatabaseSetting, S3Setting};
use crate::indexer::IndexerProvider;
use crate::pg_listener::PgListeners;
use crate::snapshot::SnapshotControl;
@ -113,9 +118,19 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result<A
let access_control =
AccessControl::new(pg_pool.clone(), metrics.access_control_metrics.clone()).await?;
info!("Setting up S3 bucket...");
let s3_client = AwsS3BucketClientImpl::new(
get_aws_s3_client(&config.s3).await?,
config.s3.bucket.clone(),
);
let collab_access_control = CollabAccessControlImpl::new(access_control.clone());
let workspace_access_control = WorkspaceAccessControlImpl::new(access_control.clone());
let collab_cache = CollabCache::new(redis_conn_manager.clone(), pg_pool.clone());
let collab_cache = CollabCache::new(
redis_conn_manager.clone(),
pg_pool.clone(),
s3_client.clone(),
);
let collab_storage_access_control = CollabStorageAccessControlImpl {
collab_access_control: Arc::new(collab_access_control.clone()),
@ -123,8 +138,8 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result<A
cache: collab_cache.clone(),
};
let snapshot_control = SnapshotControl::new(
redis_conn_manager.clone(),
pg_pool.clone(),
s3_client.clone(),
metrics.collab_metrics.clone(),
)
.await;
@ -169,3 +184,82 @@ async fn get_connection_pool(setting: &DatabaseSetting) -> Result<PgPool, Error>
.await
.map_err(|e| anyhow::anyhow!("Failed to connect to postgres database: {}", e))
}
/// Builds an S3 client from `s3_setting`.
///
/// The client uses the configured static access/secret key pair, path-style addressing and
/// the configured region. When `use_minio` is set, requests are sent to `minio_url`.
/// When `create_bucket` is set, the bucket is created (if needed) before the client is
/// returned; otherwise bucket creation is skipped.
///
/// # Errors
/// Propagates any error returned by `create_bucket_if_not_exists`.
pub async fn get_aws_s3_client(s3_setting: &S3Setting) -> Result<aws_sdk_s3::Client, Error> {
  let credentials_provider = SharedCredentialsProvider::new(Credentials::new(
    s3_setting.access_key.clone(),
    s3_setting.secret_key.expose_secret().clone(),
    None,
    None,
    "custom",
  ));

  // Configure the AWS SDK
  let mut builder = aws_sdk_s3::Config::builder()
    .credentials_provider(credentials_provider)
    .force_path_style(true)
    .region(Region::new(s3_setting.region.clone()));
  if s3_setting.use_minio {
    builder = builder.endpoint_url(&s3_setting.minio_url);
  }
  let client = aws_sdk_s3::Client::from_conf(builder.build());

  if !s3_setting.create_bucket {
    info!("Skipping bucket creation, assumed to be created externally");
    return Ok(client);
  }
  create_bucket_if_not_exists(&client, s3_setting).await?;
  Ok(client)
}
/// Sends a `CreateBucket` request for `s3_setting.bucket`.
///
/// With `use_minio` the bucket configuration carries a `BucketType::Directory` bucket info;
/// otherwise it carries a location constraint derived from `s3_setting.region`.
///
/// Outcomes:
/// * success, or a service error saying the bucket already exists / is already owned: `Ok(())`;
/// * any other service error: logged and returned as `Err`;
/// * a non-service error: logged, but `Ok(())` is still returned.
///
/// NOTE(review): the last case means a failed request does not stop start-up; this looks
/// like deliberate best-effort behaviour — confirm.
async fn create_bucket_if_not_exists(
  client: &aws_sdk_s3::Client,
  s3_setting: &S3Setting,
) -> Result<(), Error> {
  let cfg_builder = CreateBucketConfiguration::builder();
  let bucket_cfg = if s3_setting.use_minio {
    let bucket_info = BucketInfo::builder().r#type(BucketType::Directory).build();
    cfg_builder.bucket(bucket_info).build()
  } else {
    let constraint = BucketLocationConstraint::from(s3_setting.region.as_str());
    cfg_builder.location_constraint(constraint).build()
  };

  let outcome = client
    .create_bucket()
    .bucket(&s3_setting.bucket)
    .create_bucket_configuration(bucket_cfg)
    .send()
    .await;

  let err = match outcome {
    Ok(_) => {
      info!(
        "bucket created successfully: {}, region: {}",
        s3_setting.bucket, s3_setting.region
      );
      return Ok(());
    },
    Err(err) => err,
  };

  match err.as_service_error() {
    Some(
      CreateBucketError::BucketAlreadyOwnedByYou(_) | CreateBucketError::BucketAlreadyExists(_),
    ) => {
      info!("Bucket already exists");
      Ok(())
    },
    Some(_) => {
      tracing::error!("Unhandle s3 service error: {:?}", err);
      Err(err.into())
    },
    None => {
      tracing::error!("Failed to create bucket: {:?}", err);
      Ok(())
    },
  }
}

View file

@ -36,7 +36,7 @@ impl CollabStorageAccessControl for CollabStorageAccessControlImpl {
uid: &i64,
oid: &str,
) -> Result<(), AppError> {
let collab_exists = self.cache.is_exist(oid).await?;
let collab_exists = self.cache.is_exist(workspace_id, oid).await?;
if !collab_exists {
// If the collab does not exist, we should not enforce the access control. We consider the user
// has the permission to read the collab
@ -54,7 +54,7 @@ impl CollabStorageAccessControl for CollabStorageAccessControlImpl {
uid: &i64,
oid: &str,
) -> Result<(), AppError> {
let collab_exists = self.cache.is_exist(oid).await?;
let collab_exists = self.cache.is_exist(workspace_id, oid).await?;
if !collab_exists {
// If the collab does not exist, we should not enforce the access control. we consider the user
// has the permission to write the collab

View file

@ -27,8 +27,8 @@ use database::collab::{
CollabStorageAccessControl, GetCollabOrigin,
};
use database_entity::dto::{
AFAccessLevel, AFSnapshotMeta, AFSnapshotMetas, CollabParams, InsertSnapshotParams, QueryCollab,
QueryCollabParams, QueryCollabResult, SnapshotData,
AFAccessLevel, AFSnapshotMeta, AFSnapshotMetas, CollabParams, InsertSnapshotParams,
PendingCollabWrite, QueryCollab, QueryCollabParams, QueryCollabResult, SnapshotData,
};
use crate::collab::access_control::CollabStorageAccessControlImpl;
@ -38,22 +38,6 @@ use crate::snapshot::SnapshotControl;
pub type CollabAccessControlStorage = CollabStorageImpl<CollabStorageAccessControlImpl>;
/// Parameters of one collab write waiting to be persisted: the workspace it belongs to,
/// a user id, and the collab payload.
struct PendingCollabWrite {
  workspace_id: String,
  uid: i64,
  params: CollabParams,
}

impl PendingCollabWrite {
  /// Bundles the three parts of a pending write into one record.
  fn new(workspace_id: String, uid: i64, params: CollabParams) -> Self {
    Self {
      workspace_id,
      uid,
      params,
    }
  }
}
/// A wrapper around the actual storage implementation that provides access control and caching.
#[derive(Clone)]
pub struct CollabStorageImpl<AC> {
@ -119,57 +103,12 @@ where
metrics: &CollabMetrics,
records: impl ExactSizeIterator<Item = PendingCollabWrite>,
) -> Result<(), AppError> {
// Start a database transaction
let mut transaction = cache
.pg_pool()
.begin()
.await
.context("Failed to acquire transaction for writing pending collaboration data")
.map_err(AppError::from)?;
let total_records = records.len();
let mut successful_writes = 0;
// Insert each record into the database within the transaction context
let mut action_description = String::new();
for (index, record) in records.into_iter().enumerate() {
let params = record.params;
action_description = format!("{}", params);
let savepoint_name = format!("sp_{}", index);
// using savepoint to rollback the transaction if the insert fails
sqlx::query(&format!("SAVEPOINT {}", savepoint_name))
.execute(transaction.deref_mut())
.await?;
if let Err(_err) = cache
.insert_encode_collab_to_disk(&record.workspace_id, &record.uid, params, &mut transaction)
.await
{
sqlx::query(&format!("ROLLBACK TO SAVEPOINT {}", savepoint_name))
.execute(transaction.deref_mut())
.await?;
} else {
successful_writes += 1;
}
}
let successful_writes = cache.batch_insert_collab(records.collect()).await?;
metrics.record_write_collab(successful_writes, total_records as _);
// Commit the transaction to finalize all writes
match tokio::time::timeout(Duration::from_secs(10), transaction.commit()).await {
Ok(result) => {
result.map_err(AppError::from)?;
Ok(())
},
Err(_) => {
error!(
"Timeout waiting for committing the transaction for pending write:{}",
action_description
);
Err(AppError::Internal(anyhow!(
"Timeout when committing the transaction for pending collaboration data"
)))
},
}
Ok(())
}
async fn insert_collab(
@ -178,25 +117,10 @@ where
uid: &i64,
params: CollabParams,
) -> AppResult<()> {
// Start a database transaction
let mut transaction = self
.cache
.pg_pool()
.begin()
.await
.context("Failed to acquire transaction for writing pending collaboration data")
.map_err(AppError::from)?;
self
.cache
.insert_encode_collab_to_disk(workspace_id, uid, params, &mut transaction)
.insert_encode_collab_to_disk(workspace_id, uid, params)
.await?;
tokio::time::timeout(Duration::from_secs(10), transaction.commit())
.await
.map_err(|_| {
AppError::Internal(anyhow!(
"Timeout when committing the transaction for pending collaboration data"
))
})??;
Ok(())
}
@ -331,18 +255,10 @@ where
uid: &i64,
params_list: Vec<CollabParams>,
) -> Result<(), AppError> {
let mut transaction = self.cache.pg_pool().begin().await?;
insert_into_af_collab_bulk_for_user(&mut transaction, uid, workspace_id, &params_list).await?;
transaction.commit().await?;
// update the mem cache without blocking the current task
let cache = self.cache.clone();
tokio::spawn(async move {
for params in params_list {
let _ = cache.insert_encode_collab_to_mem(&params).await;
}
});
Ok(())
self
.cache
.bulk_insert_collab(workspace_id, uid, params_list)
.await
}
}
@ -364,7 +280,7 @@ where
write_immediately: bool,
) -> AppResult<()> {
params.validate()?;
let is_exist = self.cache.is_exist(&params.object_id).await?;
let is_exist = self.cache.is_exist(workspace_id, &params.object_id).await?;
// If the collab already exists, check if the user has enough permissions to update collab
// Otherwise, check if the user has enough permissions to create collab.
if is_exist {
@ -448,7 +364,7 @@ where
Duration::from_secs(120),
self
.cache
.insert_encode_collab_data(workspace_id, uid, &params, transaction),
.insert_encode_collab_data(workspace_id, uid, params, transaction),
)
.await
{
@ -494,13 +410,52 @@ where
}
}
let encode_collab = self.cache.get_encode_collab(params.inner).await?;
let encode_collab = self
.cache
.get_encode_collab(&params.workspace_id, params.inner)
.await?;
Ok(encode_collab)
}
async fn broadcast_encode_collab(
&self,
object_id: String,
collab_messages: Vec<ClientCollabMessage>,
) -> Result<(), AppError> {
let (sender, recv) = tokio::sync::oneshot::channel();
self
.rt_cmd_sender
.send(CollaborationCommand::ServerSendCollabMessage {
object_id,
collab_messages,
ret: sender,
})
.await
.map_err(|err| {
AppError::Unhandled(format!(
"Failed to send encode collab command to realtime server: {}",
err
))
})?;
match recv.await {
Ok(res) =>
if let Err(err) = res {
error!("Failed to broadcast encode collab: {}", err);
}
,
// caller may have dropped the receiver
Err(err) => warn!("Failed to receive response from realtime server: {}", err),
}
Ok(())
}
async fn batch_get_collab(
&self,
_uid: &i64,
workspace_id: &str,
queries: Vec<QueryCollab>,
from_editing_collab: bool,
) -> HashMap<String, QueryCollabResult> {
@ -558,7 +513,12 @@ where
valid_queries
};
results.extend(self.cache.batch_get_encode_collab(cache_queries).await);
results.extend(
self
.cache
.batch_get_encode_collab(workspace_id, cache_queries)
.await,
);
results
}
@ -567,20 +527,15 @@ where
.access_control
.enforce_delete(workspace_id, uid, object_id)
.await?;
self.cache.delete_collab(object_id).await?;
self.cache.delete_collab(workspace_id, object_id).await?;
Ok(())
}
async fn query_collab_meta(
&self,
object_id: &str,
collab_type: &CollabType,
) -> AppResult<CollabMetadata> {
self.cache.get_collab_meta(object_id, collab_type).await
}
async fn should_create_snapshot(&self, oid: &str) -> Result<bool, AppError> {
self.snapshot_control.should_create_snapshot(oid).await
async fn should_create_snapshot(&self, workspace_id: &str, oid: &str) -> Result<bool, AppError> {
self
.snapshot_control
.should_create_snapshot(workspace_id, oid)
.await
}
async fn create_snapshot(&self, params: InsertSnapshotParams) -> AppResult<AFSnapshotMeta> {
@ -603,42 +558,14 @@ where
.await
}
async fn get_collab_snapshot_list(&self, oid: &str) -> AppResult<AFSnapshotMetas> {
self.snapshot_control.get_collab_snapshot_list(oid).await
}
async fn broadcast_encode_collab(
async fn get_collab_snapshot_list(
&self,
object_id: String,
collab_messages: Vec<ClientCollabMessage>,
) -> Result<(), AppError> {
let (sender, recv) = tokio::sync::oneshot::channel();
workspace_id: &str,
oid: &str,
) -> AppResult<AFSnapshotMetas> {
self
.rt_cmd_sender
.send(CollaborationCommand::ServerSendCollabMessage {
object_id,
collab_messages,
ret: sender,
})
.snapshot_control
.get_collab_snapshot_list(workspace_id, oid)
.await
.map_err(|err| {
AppError::Unhandled(format!(
"Failed to send encode collab command to realtime server: {}",
err
))
})?;
match recv.await {
Ok(res) =>
if let Err(err) = res {
error!("Failed to broadcast encode collab: {}", err);
}
,
// caller may have dropped the receiver
Err(err) => warn!("Failed to receive response from realtime server: {}", err),
}
Ok(())
}
}

View file

@ -16,6 +16,18 @@ pub struct Config {
pub collab: CollabSetting,
pub redis_uri: Secret<String>,
pub ai: AISettings,
pub s3: S3Setting,
}
/// Settings for the S3-compatible object storage.
///
/// `get_configuration` fills these from the `APPFLOWY_S3_*` environment variables.
#[derive(serde::Deserialize, Clone, Debug)]
pub struct S3Setting {
/// Whether `get_aws_s3_client` should try to create the bucket
/// (`APPFLOWY_S3_CREATE_BUCKET`, default `true`).
pub create_bucket: bool,
/// Whether to point the client at `minio_url` (`APPFLOWY_S3_USE_MINIO`, default `true`).
pub use_minio: bool,
/// Endpoint URL used only when `use_minio` is set
/// (`APPFLOWY_S3_MINIO_URL`, default `http://localhost:9000`).
pub minio_url: String,
/// Access key of the static credentials (`APPFLOWY_S3_ACCESS_KEY`, default `minioadmin`).
pub access_key: String,
/// Secret key of the static credentials (`APPFLOWY_S3_SECRET_KEY`, default `minioadmin`).
pub secret_key: Secret<String>,
/// Name of the bucket to use (`APPFLOWY_S3_BUCKET`, default `appflowy`).
pub bucket: String,
/// Region passed to the client; also used as the location constraint when the bucket is
/// created without MinIO (`APPFLOWY_S3_REGION`, default empty).
pub region: String,
}
#[derive(Clone, Debug)]
@ -155,6 +167,19 @@ pub fn get_configuration() -> Result<Config, anyhow::Error> {
.parse()
.context("fail to get APPFLOWY_DATABASE_MAX_CONNECTIONS")?,
},
s3: S3Setting {
create_bucket: get_env_var("APPFLOWY_S3_CREATE_BUCKET", "true")
.parse()
.context("fail to get APPFLOWY_S3_CREATE_BUCKET")?,
use_minio: get_env_var("APPFLOWY_S3_USE_MINIO", "true")
.parse()
.context("fail to get APPFLOWY_S3_USE_MINIO")?,
minio_url: get_env_var("APPFLOWY_S3_MINIO_URL", "http://localhost:9000"),
access_key: get_env_var("APPFLOWY_S3_ACCESS_KEY", "minioadmin"),
secret_key: get_env_var("APPFLOWY_S3_SECRET_KEY", "minioadmin").into(),
bucket: get_env_var("APPFLOWY_S3_BUCKET", "appflowy"),
region: get_env_var("APPFLOWY_S3_REGION", ""),
},
gotrue: GoTrueSetting {
jwt_secret: get_env_var("APPFLOWY_GOTRUE_JWT_SECRET", "hello456").into(),
},

View file

@ -64,26 +64,18 @@ pub enum RealtimeError {
#[derive(Debug)]
pub enum CreateGroupFailedReason {
CollabWorkspaceIdNotMatch {
expect: String,
actual: String,
detail: String,
},
CollabWorkspaceIdNotMatch { expect: String, detail: String },
CannotGetCollabData,
}
impl Display for CreateGroupFailedReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CreateGroupFailedReason::CollabWorkspaceIdNotMatch {
expect,
actual,
detail,
} => {
CreateGroupFailedReason::CollabWorkspaceIdNotMatch { expect, detail } => {
write!(
f,
"Collab workspace id not match: expect {}, actual {}, detail: {}",
expect, actual, detail
"Collab workspace id not match: expect {}, detail: {}",
expect, detail
)
},
CreateGroupFailedReason::CannotGetCollabData => {

View file

@ -126,27 +126,6 @@ where
) -> Result<(), RealtimeError> {
let mut is_new_collab = false;
let params = QueryCollabParams::new(object_id, collab_type.clone(), workspace_id);
// Ensure the workspace_id matches the metadata's workspace_id when creating a collaboration object
// of type [CollabType::Folder]. In this case, both the object id and the workspace id should be
// identical.
if let Ok(metadata) = self
.storage
.query_collab_meta(object_id, &collab_type)
.await
{
if metadata.workspace_id != workspace_id {
let err =
RealtimeError::CreateGroupFailed(CreateGroupFailedReason::CollabWorkspaceIdNotMatch {
expect: metadata.workspace_id,
actual: workspace_id.to_string(),
detail: format!(
"user_id:{},app_version:{},object_id:{}:{}",
user.uid, user.app_version, object_id, collab_type
),
});
return Err(err);
}
}
let result = load_collab(user.uid, object_id, params, self.storage.clone()).await;
let (collab, _encode_collab) = {
@ -251,7 +230,7 @@ where
let encode_collab = get_latest_snapshot(
&params.workspace_id,
object_id,
&storage,
&*storage,
&params.collab_type,
)
.await?;
@ -275,7 +254,11 @@ async fn get_latest_snapshot<S>(
where
S: CollabStorage,
{
let metas = storage.get_collab_snapshot_list(object_id).await.ok()?.0;
let metas = storage
.get_collab_snapshot_list(workspace_id, object_id)
.await
.ok()?
.0;
for meta in metas {
let snapshot_data = storage
.get_collab_snapshot(workspace_id, &meta.object_id, &meta.snapshot_id)

View file

@ -61,10 +61,10 @@ where
let lock = collab.read().await;
let encode_collab =
lock.encode_collab_v1(|collab| collab_type.validate_require_data(collab))?;
let bytes = encode_collab.encode_to_bytes()?;
let data = encode_collab.doc_state;
let params = InsertSnapshotParams {
object_id,
encoded_collab_v1: bytes,
data,
workspace_id,
collab_type,
};
@ -102,7 +102,10 @@ where
tokio::spawn(async move {
sleep(std::time::Duration::from_secs(2)).await;
match storage.should_create_snapshot(&object_id).await {
match storage
.should_create_snapshot(&workspace_id, &object_id)
.await
{
Ok(true) => {
if let Err(err) =
Self::enqueue_snapshot(weak_collab, storage, workspace_id, object_id, collab_type).await

View file

@ -144,8 +144,9 @@ where
#[derive(Clone)]
pub struct CollabMetrics {
success_write_snapshot_count: Gauge,
total_write_snapshot_count: Gauge,
pub write_snapshot: Counter,
pub write_snapshot_failures: Counter,
pub read_snapshot: Counter,
success_write_collab_count: Counter,
total_write_collab_count: Counter,
success_queue_collab_count: Counter,
@ -154,8 +155,9 @@ pub struct CollabMetrics {
impl CollabMetrics {
fn init() -> Self {
Self {
success_write_snapshot_count: Gauge::default(),
total_write_snapshot_count: Default::default(),
write_snapshot: Default::default(),
write_snapshot_failures: Default::default(),
read_snapshot: Default::default(),
success_write_collab_count: Default::default(),
total_write_collab_count: Default::default(),
success_queue_collab_count: Default::default(),
@ -166,14 +168,19 @@ impl CollabMetrics {
let metrics = Self::init();
let realtime_registry = registry.sub_registry_with_prefix("collab");
realtime_registry.register(
"success_write_snapshot_count",
"success write snapshot to db",
metrics.success_write_snapshot_count.clone(),
"write_snapshot",
"snapshot write attempts counter",
metrics.write_snapshot.clone(),
);
realtime_registry.register(
"total_attempt_write_snapshot_count",
"total attempt write snapshot to db",
metrics.total_write_snapshot_count.clone(),
"write_snapshot_failures",
"counter for failed attempts to write a snapshot",
metrics.write_snapshot_failures.clone(),
);
realtime_registry.register(
"read_snapshot",
"snapshot read counter",
metrics.read_snapshot.clone(),
);
realtime_registry.register(
"success_write_collab_count",
@ -194,11 +201,6 @@ impl CollabMetrics {
metrics
}
pub fn record_write_snapshot(&self, success_attempt: i64, total_attempt: i64) {
self.success_write_snapshot_count.set(success_attempt);
self.total_write_snapshot_count.set(total_attempt);
}
pub fn record_write_collab(&self, success_attempt: u64, total_attempt: u64) {
self.success_write_collab_count.inc_by(success_attempt);
self.total_write_collab_count.inc_by(total_attempt);

View file

@ -1,61 +0,0 @@
use std::sync::Arc;
use anyhow::anyhow;
use collab::lock::Mutex;
use redis::AsyncCommands;
use app_error::AppError;
use crate::state::RedisConnectionManager;
#[derive(Clone)]
pub(crate) struct SnapshotCache {
redis_client: Arc<Mutex<RedisConnectionManager>>,
}
impl SnapshotCache {
pub fn new(redis_client: Arc<Mutex<RedisConnectionManager>>) -> Self {
Self { redis_client }
}
/// Returns all existing keys start with `prefix`
#[allow(dead_code)]
pub async fn keys(&self, prefix: &str) -> Result<Vec<String>, AppError> {
let mut redis = self.redis_client.lock().await;
let keys: Vec<String> = redis
.keys(format!("{}*", prefix))
.await
.map_err(|err| AppError::Internal(err.into()))?;
Ok(keys)
}
pub async fn insert(&self, key: &str, value: Vec<u8>) -> Result<(), AppError> {
let mut redis = self.redis_client.lock().await;
redis
.set_ex(key, value, 60 * 60 * 24)
.await
.map_err(|err| AppError::Internal(err.into()))?;
Ok(())
}
pub async fn try_get(&self, key: &str) -> Result<Option<Vec<u8>>, AppError> {
let mut redis = self
.redis_client
.try_lock()
.map_err(|_| AppError::Internal(anyhow!("lock error")))?;
let value = redis
.get::<_, Option<Vec<u8>>>(key)
.await
.map_err(|err| AppError::Internal(err.into()))?;
Ok(value)
}
pub async fn remove(&self, key: &str) -> Result<(), AppError> {
let mut redis = self.redis_client.lock().await;
redis
.del(key)
.await
.map_err(|err| AppError::Internal(err.into()))?;
Ok(())
}
}

View file

@ -1,5 +1,3 @@
mod cache;
mod queue;
mod snapshot_control;
pub use snapshot_control::*;

View file

@ -1,83 +0,0 @@
use collab_entity::CollabType;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::AtomicI64;
/// A binary heap of [`PendingItem`]s plus the counter that hands out their sequence numbers.
///
/// Dereferences to the inner `BinaryHeap`, so heap methods can be called on it directly.
pub(crate) struct PendingQueue {
  id_gen: AtomicI64,
  queue: BinaryHeap<PendingItem>,
}

impl PendingQueue {
  /// Creates an empty queue whose sequence counter starts at zero.
  pub(crate) fn new() -> Self {
    PendingQueue {
      id_gen: AtomicI64::new(0),
      queue: BinaryHeap::new(),
    }
  }

  /// Builds an item stamped with the next sequence number. The item is not queued;
  /// use `push_item` for that.
  pub(crate) fn generate_item(
    &mut self,
    workspace_id: String,
    object_id: String,
    collab_type: CollabType,
  ) -> PendingItem {
    use std::sync::atomic::Ordering::SeqCst;

    let seq = self.id_gen.fetch_add(1, SeqCst);
    PendingItem {
      workspace_id,
      object_id,
      seq,
      collab_type,
    }
  }

  /// Adds `item` to the heap.
  pub(crate) fn push_item(&mut self, item: PendingItem) {
    self.queue.push(item);
  }
}

impl Deref for PendingQueue {
  type Target = BinaryHeap<PendingItem>;

  fn deref(&self) -> &BinaryHeap<PendingItem> {
    &self.queue
  }
}

impl DerefMut for PendingQueue {
  fn deref_mut(&mut self) -> &mut BinaryHeap<PendingItem> {
    &mut self.queue
  }
}
/// One queued entry, identified by `object_id` and the sequence number it was stamped with.
#[derive(Debug, Clone)]
pub(crate) struct PendingItem {
  pub(crate) workspace_id: String,
  pub(crate) object_id: String,
  pub(crate) seq: i64,
  pub(crate) collab_type: CollabType,
}

impl PartialEq<Self> for PendingItem {
  /// Two items are equal when both `object_id` and `seq` match.
  fn eq(&self, other: &Self) -> bool {
    self.object_id == other.object_id && self.seq == other.seq
  }
}

impl Eq for PendingItem {}

impl PartialOrd for PendingItem {
  fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
    Some(self.cmp(other))
  }
}

impl Ord for PendingItem {
  fn cmp(&self, other: &Self) -> Ordering {
    // smaller seq is higher priority
    other
      .seq
      .cmp(&self.seq)
      // Tie-break on `object_id` so that `cmp` returns `Equal` exactly when `eq` is true,
      // as the `Ord` contract requires; ordering by `seq` is otherwise unchanged.
      .then_with(|| self.object_id.cmp(&other.object_id))
  }
}

View file

@ -1,94 +1,96 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use anyhow::anyhow;
use async_stream::stream;
use chrono::{DateTime, Utc};
use collab::lock::{Mutex, RwLock};
use futures_util::StreamExt;
use collab::entity::{EncodedCollab, EncoderVersion};
use sqlx::PgPool;
use tokio::time::interval;
use tracing::{debug, error, trace, warn};
use validator::Validate;
use app_error::AppError;
use collab_rt_protocol::spawn_blocking_validate_encode_collab;
use database::collab::{
create_snapshot_and_maintain_limit, get_all_collab_snapshot_meta, latest_snapshot_time,
select_snapshot, AppResult, COLLAB_SNAPSHOT_LIMIT, SNAPSHOT_PER_HOUR,
get_all_collab_snapshot_meta, latest_snapshot_time, select_snapshot, AppResult,
COLLAB_SNAPSHOT_LIMIT, SNAPSHOT_PER_HOUR,
};
use database::file::s3_client_impl::AwsS3BucketClientImpl;
use database::file::{BucketClient, ResponseBlob};
use database_entity::dto::{
AFSnapshotMeta, AFSnapshotMetas, InsertSnapshotParams, SnapshotData, ZSTD_COMPRESSION_LEVEL,
};
use database_entity::dto::{AFSnapshotMeta, AFSnapshotMetas, InsertSnapshotParams, SnapshotData};
use crate::metrics::CollabMetrics;
use crate::snapshot::cache::SnapshotCache;
use crate::snapshot::queue::PendingQueue;
use crate::state::RedisConnectionManager;
pub type SnapshotCommandReceiver = tokio::sync::mpsc::Receiver<SnapshotCommand>;
pub type SnapshotCommandSender = tokio::sync::mpsc::Sender<SnapshotCommand>;
pub const SNAPSHOT_TICK_INTERVAL: Duration = Duration::from_secs(2);
pub enum SnapshotCommand {
InsertSnapshot(InsertSnapshotParams),
Tick(tokio::sync::oneshot::Sender<SnapshotMetric>),
/// Returns the object key of a collab snapshot:
/// `collabs/<workspace_id>/<object_id>/snapshot_<hex>.v1.zstd`.
///
/// `<hex>` is `u64::MAX - snapshot_id` written as exactly 16 zero-padded lowercase hex
/// digits. Because the id is inverted and the width is fixed, a larger `snapshot_id`
/// produces a lexicographically smaller key.
fn collab_snapshot_key(workspace_id: &str, object_id: &str, snapshot_id: i64) -> String {
  let snapshot_id = u64::MAX - snapshot_id as u64;
  // `{:016x}` pads with zeros. A plain `{:16x}` pads with spaces, which puts blanks in the
  // key and makes the hex part unparsable whenever the inverted id has fewer than 16 digits.
  format!(
    "collabs/{}/{}/snapshot_{:016x}.v1.zstd",
    workspace_id, object_id, snapshot_id
  )
}
/// Returns the prefix `collabs/<workspace_id>/<object_id>/snapshot_` that every snapshot
/// key of the given collab starts with.
fn collab_snapshot_prefix(workspace_id: &str, object_id: &str) -> String {
  ["collabs", workspace_id, object_id, "snapshot_"].join("/")
}
/// Recovers the creation time encoded in a snapshot object key.
///
/// The last path segment is stripped of its `snapshot_` prefix and `.v1.zstd` suffix, the
/// remainder is parsed as hex, inverted (`u64::MAX - value`) and read as a millisecond
/// timestamp. Returns `None` when the key has no `/`, the remainder is not valid hex, or
/// the resulting timestamp is out of range.
fn get_timestamp(object_key: &str) -> Option<DateTime<Utc>> {
  let file_name = &object_key[object_key.rfind('/')? + 1..];
  let hex = file_name
    .trim_start_matches("snapshot_")
    .trim_end_matches(".v1.zstd");
  let inverted = u64::from_str_radix(hex, 16).ok()?;
  let millis = (u64::MAX - inverted) as i64;
  DateTime::from_timestamp_millis(millis)
}
/// Rebuilds snapshot metadata from a snapshot object key.
///
/// The second-to-last path segment is taken as the object id. The last segment, stripped
/// of `snapshot_` and `.v1.zstd`, is parsed as hex and inverted (`u64::MAX - value`) to
/// get the snapshot id, which is also read as the millisecond creation timestamp.
/// Returns `None` when the key has fewer than two `/`, the hex part does not parse, or the
/// timestamp is out of range.
fn get_meta(objct_key: String) -> Option<AFSnapshotMeta> {
  // Walk the key from the right: file name, object id, then whatever precedes them.
  let mut segments = objct_key.rsplitn(3, '/');
  let file_name = segments.next()?;
  let object_id = segments.next()?;
  // A third piece exists only when the key holds at least two '/' separators.
  segments.next()?;

  let hex = file_name
    .trim_start_matches("snapshot_")
    .trim_end_matches(".v1.zstd");
  let inverted = u64::from_str_radix(hex, 16).ok()?;
  let snapshot_id = (u64::MAX - inverted) as i64;
  let created_at = DateTime::from_timestamp_millis(snapshot_id)?;

  Some(AFSnapshotMeta {
    snapshot_id,
    object_id: object_id.to_string(),
    created_at,
  })
}
#[derive(Clone)]
// #[deprecated(note = "snapshot is implemented in the appflowy-history")]
pub struct SnapshotControl {
cache: SnapshotCache,
command_sender: SnapshotCommandSender,
pg_pool: PgPool,
s3: AwsS3BucketClientImpl,
collab_metrics: Arc<CollabMetrics>,
}
impl SnapshotControl {
pub async fn new(
redis_client: RedisConnectionManager,
pg_pool: PgPool,
s3: AwsS3BucketClientImpl,
collab_metrics: Arc<CollabMetrics>,
) -> Self {
let redis_client = Arc::new(Mutex::from(redis_client));
let (command_sender, rx) = tokio::sync::mpsc::channel(2000);
let cache = SnapshotCache::new(redis_client);
let runner = SnapshotCommandRunner::new(pg_pool.clone(), cache.clone(), rx);
tokio::spawn(runner.run());
let cloned_sender = command_sender.clone();
tokio::spawn(async move {
let mut interval = interval(SNAPSHOT_TICK_INTERVAL);
loop {
interval.tick().await;
let (tx, rx) = tokio::sync::oneshot::channel();
if let Err(err) = cloned_sender.send(SnapshotCommand::Tick(tx)).await {
error!("Failed to send tick command: {}", err);
}
if let Ok(metric) = rx.await {
collab_metrics.record_write_snapshot(
metric.success_write_snapshot_count,
metric.total_write_snapshot_count,
);
}
}
});
Self {
cache,
command_sender,
pg_pool,
s3,
collab_metrics,
}
}
pub async fn should_create_snapshot(&self, oid: &str) -> Result<bool, AppError> {
pub async fn should_create_snapshot(
&self,
workspace_id: &str,
oid: &str,
) -> Result<bool, AppError> {
if oid.is_empty() {
warn!("unexpected empty object id when checking should_create_snapshot");
return Ok(false);
}
let latest_created_at = self.latest_snapshot_time(oid).await?;
let latest_created_at = self.latest_snapshot_time(workspace_id, oid).await?;
// Subtracting a fixed duration that is known not to cause underflow. If `checked_sub_signed` returns `None`,
// it indicates an error in calculation, thus defaulting to creating a snapshot just in case.
let threshold_time = Utc::now().checked_sub_signed(chrono::Duration::hours(SNAPSHOT_PER_HOUR));
@ -112,53 +114,119 @@ impl SnapshotControl {
params.validate()?;
debug!("create snapshot for object:{}", params.object_id);
match self.pg_pool.try_begin().await {
Ok(Some(transaction)) => {
let meta = create_snapshot_and_maintain_limit(
transaction,
&params.workspace_id,
&params.object_id,
&params.encoded_collab_v1,
COLLAB_SNAPSHOT_LIMIT,
)
.await?;
Ok(meta)
},
_ => Err(AppError::Internal(anyhow!(
"fail to acquire transaction to create snapshot for object:{}",
params.object_id,
))),
self.collab_metrics.write_snapshot.inc();
let timestamp = Utc::now();
let snapshot_id = timestamp.timestamp_millis();
let key = collab_snapshot_key(&params.workspace_id, &params.object_id, snapshot_id);
let compressed = zstd::encode_all(params.data.as_ref(), ZSTD_COMPRESSION_LEVEL)?;
if let Err(err) = self.s3.put_blob(&key, compressed.into(), None).await {
self.collab_metrics.write_snapshot_failures.inc();
return Err(err);
}
// drop old snapshots if exceeds limit
let list = self
.s3
.list_dir(
&collab_snapshot_prefix(&params.workspace_id, &params.object_id),
100,
)
.await?;
if list.len() > COLLAB_SNAPSHOT_LIMIT as usize {
debug!(
"drop {} snapshots for `{}`",
list.len() - COLLAB_SNAPSHOT_LIMIT as usize,
params.object_id
);
let trimmed: Vec<_> = list
.into_iter()
.skip(COLLAB_SNAPSHOT_LIMIT as usize)
.collect();
self.s3.delete_blobs(trimmed).await?;
}
Ok(AFSnapshotMeta {
snapshot_id,
object_id: params.object_id,
created_at: timestamp,
})
}
pub async fn get_collab_snapshot(&self, snapshot_id: &i64) -> AppResult<SnapshotData> {
match select_snapshot(&self.pg_pool, snapshot_id).await? {
None => Err(AppError::RecordNotFound(format!(
"Can't find the snapshot with id:{}",
snapshot_id
))),
Some(row) => Ok(SnapshotData {
object_id: row.oid,
encoded_collab_v1: row.blob,
workspace_id: row.workspace_id.to_string(),
}),
/// Loads one snapshot of `object_id`, looking in S3 first and in Postgres second.
///
/// The S3 object is addressed by `collab_snapshot_key(workspace_id, object_id, snapshot_id)`.
/// On a hit the `read_snapshot` metric is incremented, the blob is zstd-decompressed and wrapped
/// into a V1 [`EncodedCollab`] with a default state vector before being re-encoded to bytes.
///
/// # Errors
/// - `AppError::RecordNotFound` when neither S3 nor Postgres holds the snapshot.
/// - Any other S3 error is returned unchanged; decompression, encoding and database errors are
///   propagated with `?`.
pub async fn get_collab_snapshot(
  &self,
  workspace_id: &str,
  object_id: &str,
  snapshot_id: &i64,
) -> AppResult<SnapshotData> {
  let key = collab_snapshot_key(workspace_id, object_id, *snapshot_id);
  let resp = match self.s3.get_blob(&key).await {
    Ok(resp) => resp,
    Err(AppError::RecordNotFound(_)) => {
      debug!(
        "snapshot {} for `{}` not found in s3: fallback to postgres",
        snapshot_id, object_id
      );
      // Only a missing S3 object triggers the database lookup.
      let row = select_snapshot(&self.pg_pool, workspace_id, object_id, snapshot_id).await?;
      return match row {
        Some(row) => Ok(SnapshotData {
          object_id: object_id.to_string(),
          encoded_collab_v1: row.blob,
          workspace_id: workspace_id.to_string(),
        }),
        None => Err(AppError::RecordNotFound(format!(
          "Can't find the snapshot with id:{}",
          snapshot_id
        ))),
      };
    },
    Err(err) => return Err(err),
  };

  self.collab_metrics.read_snapshot.inc();
  let doc_state = zstd::decode_all(&*resp.to_blob())?;
  let encoded_collab = EncodedCollab {
    state_vector: Default::default(),
    doc_state: doc_state.into(),
    version: EncoderVersion::V1,
  };
  let encoded_collab_v1 = encoded_collab.encode_to_bytes()?;
  Ok(SnapshotData {
    object_id: object_id.to_string(),
    encoded_collab_v1,
    workspace_id: workspace_id.to_string(),
  })
}
/// Returns list of snapshots for given object_id in descending order of creation time.
pub async fn get_collab_snapshot_list(&self, oid: &str) -> AppResult<AFSnapshotMetas> {
let metas = get_all_collab_snapshot_meta(&self.pg_pool, oid).await?;
Ok(metas)
/// Lists snapshot metadata for `oid` inside `workspace_id`.
///
/// Up to `COLLAB_SNAPSHOT_LIMIT` keys are listed under the object's S3 snapshot prefix and each
/// key is turned into metadata with `get_meta` (keys for which it yields nothing are skipped).
/// When S3 lists no keys at all, the metadata stored in Postgres is returned instead.
///
/// # Errors
/// Propagates S3 listing errors and database errors.
pub async fn get_collab_snapshot_list(
  &self,
  workspace_id: &str,
  oid: &str,
) -> AppResult<AFSnapshotMetas> {
  let prefix = collab_snapshot_prefix(workspace_id, oid);
  let keys = self
    .s3
    .list_dir(&prefix, COLLAB_SNAPSHOT_LIMIT as usize)
    .await?;

  if !keys.is_empty() {
    let metas: Vec<_> = keys.into_iter().filter_map(get_meta).collect();
    return Ok(AFSnapshotMetas(metas));
  }

  // Nothing under the S3 prefix: fall back to the metadata kept in Postgres.
  let metas = get_all_collab_snapshot_meta(&self.pg_pool, oid).await?;
  Ok(metas)
}
/// Validates `params` and schedules the snapshot to be written; the caller does not wait for
/// the write itself.
///
/// NOTE(review): two scheduling paths appear back to back in this span: a send over
/// `command_sender` and a spawned task that calls `create_snapshot`. Both consume `params`, so
/// this presumably shows the removed and the added lines of a change together. Confirm which
/// path is live.
///
/// # Errors
/// Returns the validation error from `params.validate()`, or `AppError::Internal` when the
/// command channel rejects the message.
pub async fn queue_snapshot(&self, params: InsertSnapshotParams) -> Result<(), AppError> {
params.validate()?;
trace!("Queuing snapshot for {}", params.object_id);
// Channel path: hand the params over as an `InsertSnapshot` command.
self
.command_sender
.send(SnapshotCommand::InsertSnapshot(params))
.await
.map_err(|err| AppError::Internal(err.into()))?;
// Task path: write the snapshot on a spawned task using a clone of this controller.
let ctrl = self.clone();
tokio::spawn(async move {
// A failure is only logged; it is not reported back to the caller of `queue_snapshot`.
if let Err(err) = ctrl.create_snapshot(params).await {
error!("Failed to create snapshot: {}", err);
}
});
Ok(())
}
@ -168,165 +236,22 @@ impl SnapshotControl {
object_id: &str,
snapshot_id: &i64,
) -> Result<SnapshotData, AppError> {
let key = SnapshotKey::from_object_id(object_id);
let encoded_collab_v1 = self.cache.try_get(&key.0).await.unwrap_or(None);
match encoded_collab_v1 {
None => self.get_collab_snapshot(snapshot_id).await,
Some(encoded_collab_v1) => Ok(SnapshotData {
encoded_collab_v1,
workspace_id: workspace_id.to_string(),
object_id: object_id.to_string(),
}),
}
self
.get_collab_snapshot(workspace_id, object_id, snapshot_id)
.await
}
/// Returns the creation time of the most recent snapshot of `oid` recorded in Postgres, or
/// `None` when the lookup yields nothing.
///
/// # Errors
/// Propagates the database error, converted into `AppError`.
async fn latest_snapshot_time(&self, oid: &str) -> Result<Option<DateTime<Utc>>, AppError> {
  Ok(latest_snapshot_time(oid, &self.pg_pool).await?)
}
}
/// Background worker that receives `SnapshotCommand`s and writes queued snapshots to Postgres.
struct SnapshotCommandRunner {
// Pool used to open the transaction in which a snapshot is written.
pg_pool: PgPool,
// Pending snapshot items; pushed on `InsertSnapshot`, popped when a batch is processed.
queue: RwLock<PendingQueue>,
// Holds the encoded collab bytes of queued snapshots, keyed by `SnapshotKey`.
cache: SnapshotCache,
// Command receiver; wrapped in `Option` so `run` can take ownership of it exactly once.
recv: Option<SnapshotCommandReceiver>,
// Number of snapshots written successfully.
success_attempts: AtomicU64,
// Number of items popped from the queue for processing.
total_attempts: AtomicU64,
}
impl SnapshotCommandRunner {
fn new(pg_pool: PgPool, cache: SnapshotCache, recv: SnapshotCommandReceiver) -> Self {
let queue = PendingQueue::new();
Self {
pg_pool,
queue: RwLock::from(queue),
cache,
recv: Some(recv),
success_attempts: Default::default(),
total_attempts: Default::default(),
}
}
/// Consumes the runner and handles commands one at a time until the command channel closes.
///
/// # Panics
/// Panics if the receiver has already been taken out of the runner.
async fn run(mut self) {
  let mut rx = self.recv.take().expect("Only take once");
  // Adapt the channel into a stream that ends once `recv` yields `None`.
  let commands = stream! {
    while let Some(command) = rx.recv().await {
      yield command;
    }
  };
  let runner = &self;
  commands
    .for_each(move |command| async move {
      runner.handle_command(command).await;
    })
    .await;
}
/// Dispatches a single command.
///
/// - `InsertSnapshot`: registers a pending item in the queue and stores the encoded collab in
///   the cache under the item's `SnapshotKey`. A cache failure is logged, not returned.
/// - `Tick`: processes the next queued item (errors are logged) and replies on `tx` with the
///   current success/total counters.
async fn handle_command(&self, command: SnapshotCommand) {
  match command {
    SnapshotCommand::InsertSnapshot(params) => {
      // The write guard lives only inside this block, so it is released before the cache call.
      let key = {
        let mut pending = self.queue.write().await;
        let item = pending.generate_item(params.workspace_id, params.object_id, params.collab_type);
        let key = SnapshotKey::from_object_id(&item.object_id);
        pending.push_item(item);
        key
      };
      if let Err(err) = self.cache.insert(&key.0, params.encoded_collab_v1).await {
        error!("Failed to insert snapshot to cache: {}", err);
      }
    },
    SnapshotCommand::Tick(tx) => {
      if let Err(e) = self.process_next_batch().await {
        error!("Failed to process next batch: {}", e);
      }
      let metric = SnapshotMetric {
        success_write_snapshot_count: self.success_attempts.load(Ordering::Relaxed) as i64,
        total_write_snapshot_count: self.total_attempts.load(Ordering::Relaxed) as i64,
      };
      // The result of the reply is deliberately ignored.
      let _ = tx.send(metric);
    },
  }
}
async fn process_next_batch(&self) -> Result<(), AppError> {
let mut queue = self.queue.write().await;
let next_item = match queue.pop() {
Some(item) => item,
None => return Ok(()),
};
self.total_attempts.fetch_add(1, Ordering::Relaxed);
let key = SnapshotKey::from_object_id(&next_item.object_id);
// Attempt to fetch the collab data from the cache
let encoded_collab_v1 = match self.cache.try_get(&key.0).await {
Ok(Some(data)) => data,
Ok(None) => return Ok(()), // Cache miss, no data to process
Err(_) => {
queue.push_item(next_item); // Push back to queue on error
return Ok(());
},
};
// Validate collab data before processing
let result = spawn_blocking_validate_encode_collab(
&next_item.object_id,
&encoded_collab_v1,
&next_item.collab_type,
)
.await;
if result.is_err() {
return Ok(());
}
// Start a transaction
let transaction = match self.pg_pool.try_begin().await {
Ok(Some(tx)) => tx,
_ => {
debug!("Failed to start transaction to write snapshot, retrying later");
queue.push_item(next_item);
return Ok(());
},
};
// Create the snapshot and enforce limits
match create_snapshot_and_maintain_limit(
transaction,
&next_item.workspace_id,
&next_item.object_id,
&encoded_collab_v1,
COLLAB_SNAPSHOT_LIMIT,
)
.await
{
Ok(_) => {
trace!(
"successfully created snapshot for {}, remaining task: {}",
next_item.object_id,
queue.len()
);
let _ = self.cache.remove(&key.0).await;
self.success_attempts.fetch_add(1, Ordering::Relaxed);
Ok(())
},
Err(e) => Err(e), // Return the error if snapshot creation fails
/// Returns the creation time of the latest snapshot of `oid`.
///
/// A single key is listed under the object's S3 snapshot prefix; if one comes back, its
/// timestamp is extracted with `get_timestamp`. If S3 lists nothing, the time recorded in
/// Postgres is used instead.
///
/// # Errors
/// Propagates S3 listing errors and database errors.
async fn latest_snapshot_time(
  &self,
  workspace_id: &str,
  oid: &str,
) -> Result<Option<DateTime<Utc>>, AppError> {
  let prefix = collab_snapshot_prefix(workspace_id, oid);
  let mut keys = self.s3.list_dir(&prefix, 1).await?;
  match keys.pop() {
    Some(key) => Ok(get_timestamp(&key)),
    None => {
      let time = latest_snapshot_time(oid, &self.pg_pool).await?;
      Ok(time)
    },
  }
}
}
/// Namespace placed in front of every snapshot cache key.
const SNAPSHOT_PREFIX: &str = "full_snapshot";

/// Cache key of a queued snapshot, formatted as `full_snapshot:<object_id>`.
struct SnapshotKey(String);

impl SnapshotKey {
  /// Builds the cache key for `object_id`.
  fn from_object_id(object_id: &str) -> Self {
    let key = format!("{}:{}", SNAPSHOT_PREFIX, object_id);
    SnapshotKey(key)
  }
}
/// Snapshot write counters sent back in reply to a `SnapshotCommand::Tick`.
pub struct SnapshotMetric {
// Number of snapshot writes that succeeded.
success_write_snapshot_count: i64,
// Number of snapshot writes that were attempted.
total_write_snapshot_count: i64,
}

View file

@ -50,4 +50,5 @@ md5.workspace = true
base64.workspace = true
prometheus-client = "0.22.3"
reqwest = "0.12.5"
zstd.workspace = true

View file

@ -9,7 +9,7 @@ use crate::s3_client::S3Client;
use bytes::Bytes;
use collab::core::origin::CollabOrigin;
use collab::entity::EncodedCollab;
use collab::entity::{EncodedCollab, EncoderVersion};
use collab_database::workspace_database::WorkspaceDatabase;
use collab_entity::CollabType;
use collab_folder::{Folder, View, ViewLayout};
@ -33,7 +33,7 @@ use collab_importer::zip_tool::async_zip::async_unzip;
use collab_importer::zip_tool::sync_zip::sync_unzip;
use futures::stream::FuturesUnordered;
use futures::{stream, AsyncBufRead, StreamExt};
use futures::{stream, AsyncBufRead, AsyncReadExt, StreamExt};
use infra::env_util::get_env_var;
use redis::aio::ConnectionManager;
use redis::streams::{
@ -884,8 +884,14 @@ async fn process_unzip_file(
);
// 1. Open the workspace folder
let folder_collab =
get_encode_collab_from_bytes(&imported.workspace_id, &CollabType::Folder, pg_pool).await?;
let folder_collab = get_encode_collab_from_bytes(
&imported.workspace_id,
&imported.workspace_id,
&CollabType::Folder,
pg_pool,
s3_client,
)
.await?;
let mut folder = Folder::from_collab_doc_state(
import_task.uid,
CollabOrigin::Server,
@ -959,8 +965,14 @@ async fn process_unzip_file(
// 4. Edit workspace database collab and then encode workspace database collab
if !database_view_ids_by_database_id.is_empty() {
let w_db_collab =
get_encode_collab_from_bytes(&w_database_id, &CollabType::WorkspaceDatabase, pg_pool).await?;
let w_db_collab = get_encode_collab_from_bytes(
&import_task.workspace_id,
&w_database_id,
&CollabType::WorkspaceDatabase,
pg_pool,
s3_client,
)
.await?;
let mut w_database = WorkspaceDatabase::from_collab_doc_state(
&w_database_id,
CollabOrigin::Server,
@ -1310,22 +1322,41 @@ async fn upload_file_to_s3(
}
/// Loads the encoded collab of `object_id`.
///
/// NOTE(review): two bodies appear back to back in this span: a Postgres-only read that decodes
/// on a blocking thread, followed by an S3-first read with a Postgres fallback. This presumably
/// shows the removed and the added lines of a change together. Confirm which body is live.
///
/// # Errors
/// Storage, read, decompression and decode failures are reported as `ImportError`.
async fn get_encode_collab_from_bytes(
workspace_id: &str,
object_id: &str,
collab_type: &CollabType,
pg_pool: &PgPool,
s3: &Arc<dyn S3Client>,
) -> Result<EncodedCollab, ImportError> {
// Postgres-only body: read the stored blob and decode it on a blocking thread.
let bytes = select_blob_from_af_collab(pg_pool, collab_type, object_id)
.await
.map_err(|err| ImportError::Internal(err.into()))?;
tokio::task::spawn_blocking(move || match EncodedCollab::decode_from_bytes(&bytes) {
Ok(encoded_collab) => Ok(encoded_collab),
Err(err) => Err(ImportError::Internal(anyhow!(
"Failed to decode collab from bytes: {:?}",
err
))),
})
.await
.map_err(|err| ImportError::Internal(err.into()))?
// S3-first body: the object is looked up under `collab_key(workspace_id, object_id)`.
let key = collab_key(workspace_id, object_id);
match s3.get_blob_stream(&key).await {
Ok(mut resp) => {
// Pre-size the buffer from the reported length, defaulting to 1024 when it is absent.
// NOTE(review): `as usize` would wrap if `content_length` were a negative signed value; confirm its type.
let mut buf = Vec::with_capacity(resp.content_length.unwrap_or(1024) as usize);
resp
.stream
.read_to_end(&mut buf)
.await
.map_err(|err| ImportError::Internal(err.into()))?;
// The S3 payload is zstd-compressed; the decompressed bytes become the V1 doc state,
// paired with a default state vector.
let decompressed = zstd::decode_all(&*buf).map_err(|e| ImportError::Internal(e.into()))?;
Ok(EncodedCollab {
state_vector: Default::default(),
doc_state: decompressed.into(),
version: EncoderVersion::V1,
})
},
Err(WorkerError::RecordNotFound(_)) => {
// fallback to postgres
let bytes = select_blob_from_af_collab(pg_pool, collab_type, object_id)
.await
.map_err(|err| ImportError::Internal(err.into()))?;
// The Postgres blob is decoded directly with `EncodedCollab::decode_from_bytes`.
Ok(
EncodedCollab::decode_from_bytes(&bytes)
.map_err(|err| ImportError::Internal(err.into()))?,
)
},
// Any other S3 error is converted and returned as is.
Err(err) => return Err(err.into()),
}
}
/// Ensure the consumer group exists, if not, create it.
@ -1546,3 +1577,10 @@ async fn insert_meta_from_path(
file_size,
})
}
/// Returns the S3 object key of a collab:
/// `collabs/<workspace_id>/<object_id>/encoded_collab.v1.zstd`.
fn collab_key(workspace_id: &str, object_id: &str) -> String {
  format!("collabs/{}/{}/encoded_collab.v1.zstd", workspace_id, object_id)
}

View file

@ -89,6 +89,13 @@ impl S3Client for S3ClientImpl {
let stream = output.body.into_async_read().compat();
let content_type = output.content_type;
let content_length = output.content_length;
trace!(
"get object from S3: {} ({:?} bytes)",
object_key,
content_length
);
Ok(S3StreamResponse {
stream: Box::new(stream),
content_type,
@ -127,7 +134,10 @@ impl S3Client for S3ClientImpl {
.send()
.await
{
Ok(_) => Ok(()),
Ok(_) => {
trace!("put object to S3: {}", object_key);
Ok(())
},
Err(err) => match err {
SdkError::TimeoutError(_) | SdkError::DispatchFailure(_) | SdkError::ServiceError(_) => {
Err(WorkerError::S3ServiceUnavailable(format!(
@ -153,7 +163,10 @@ impl S3Client for S3ClientImpl {
.send()
.await
{
Ok(_) => Ok(()),
Ok(_) => {
trace!("deleted object from S3: {}", object_key);
Ok(())
},
Err(SdkError::ServiceError(service_err)) => Err(WorkerError::from(anyhow!(
"Failed to delete object from S3: {:?}",
service_err

View file

@ -288,7 +288,7 @@ async fn put_blob_handler(
let file_stream = ByteStream::from(content);
state
.bucket_storage
.put_blob(path, file_stream, content_type, file_size)
.put_blob_with_content_type(path, file_stream, content_type, file_size)
.await
.map_err(AppResponseError::from)?;
@ -562,7 +562,7 @@ async fn put_blob_handler_v1(
let file_stream = ByteStream::from(content);
state
.bucket_storage
.put_blob(
.put_blob_with_content_type(
BlobPathV1::from((path, file_id)),
file_stream,
content_type,

View file

@ -1,5 +1,5 @@
use access_control::act::Action;
use actix_web::web::{Bytes, Payload};
use actix_web::web::{Bytes, Path, Payload};
use actix_web::web::{Data, Json, PayloadConfig};
use actix_web::{web, Scope};
use actix_web::{HttpRequest, Result};
@ -1122,7 +1122,7 @@ async fn create_collab_snapshot_handler(
.get_user_uid(&user_uuid)
.await
.map_err(AppResponseError::from)?;
let encoded_collab_v1 = state
let data = state
.collab_access_control_storage
.get_encode_collab(
GetCollabOrigin::User { uid },
@ -1130,15 +1130,14 @@ async fn create_collab_snapshot_handler(
true,
)
.await?
.encode_to_bytes()
.unwrap();
.doc_state;
let meta = state
.collab_access_control_storage
.create_snapshot(InsertSnapshotParams {
object_id,
workspace_id,
encoded_collab_v1,
data,
collab_type,
})
.await?;
@ -1152,10 +1151,10 @@ async fn get_all_collab_snapshot_list_handler(
path: web::Path<(String, String)>,
state: Data<AppState>,
) -> Result<Json<AppResponse<AFSnapshotMetas>>> {
let (_, object_id) = path.into_inner();
let (workspace_id, object_id) = path.into_inner();
let data = state
.collab_access_control_storage
.get_collab_snapshot_list(&object_id)
.get_collab_snapshot_list(&workspace_id, &object_id)
.await
.map_err(AppResponseError::from)?;
Ok(Json(AppResponse::Ok().with_data(data)))
@ -1164,9 +1163,11 @@ async fn get_all_collab_snapshot_list_handler(
#[instrument(level = "debug", skip(payload, state), err)]
async fn batch_get_collab_handler(
user_uuid: UserUuid,
path: Path<String>,
state: Data<AppState>,
payload: Json<BatchQueryCollabParams>,
) -> Result<Json<AppResponse<BatchQueryCollabResult>>> {
let workspace_id = path.into_inner();
let uid = state
.user_cache
.get_user_uid(&user_uuid)
@ -1175,7 +1176,7 @@ async fn batch_get_collab_handler(
let result = BatchQueryCollabResult(
state
.collab_access_control_storage
.batch_get_collab(&uid, payload.into_inner().0, false)
.batch_get_collab(&uid, &workspace_id, payload.into_inner().0, false)
.await,
);
Ok(Json(AppResponse::Ok().with_data(result)))
@ -1259,7 +1260,11 @@ async fn add_collab_member_handler(
state: Data<AppState>,
) -> Result<Json<AppResponse<()>>> {
let payload = payload.into_inner();
if !state.collab_cache.is_exist(&payload.object_id).await? {
if !state
.collab_cache
.is_exist(&payload.workspace_id, &payload.object_id)
.await?
{
return Err(
AppError::RecordNotFound(format!(
"Fail to insert collab member. The Collab with object_id {} does not exist",
@ -1286,7 +1291,11 @@ async fn update_collab_member_handler(
) -> Result<Json<AppResponse<()>>> {
let payload = payload.into_inner();
if !state.collab_cache.is_exist(&payload.object_id).await? {
if !state
.collab_cache
.is_exist(&payload.workspace_id, &payload.object_id)
.await?
{
return Err(
AppError::RecordNotFound(format!(
"Fail to update collab member. The Collab with object_id {} does not exist",

View file

@ -281,7 +281,11 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result<A
} else {
Arc::new(NoOpsRealtimeCollabAccessControlImpl::new())
};
let collab_cache = CollabCache::new(redis_conn_manager.clone(), pg_pool.clone());
let collab_cache = CollabCache::new(
redis_conn_manager.clone(),
pg_pool.clone(),
s3_client.clone(),
);
let collab_storage_access_control = CollabStorageAccessControlImpl {
collab_access_control: collab_access_control.clone(),
@ -289,8 +293,8 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result<A
cache: collab_cache.clone(),
};
let snapshot_control = SnapshotControl::new(
redis_conn_manager.clone(),
pg_pool.clone(),
s3_client.clone(),
metrics.collab_metrics.clone(),
)
.await;

View file

@ -603,7 +603,7 @@ pub async fn list_database_row_details(
}
let database_row_details = collab_storage
.batch_get_collab(&uid, query_collabs, true)
.batch_get_collab(&uid, &workspace_uuid_str, query_collabs, true)
.await
.into_iter()
.flat_map(|(id, result)| match result {

View file

@ -1173,7 +1173,7 @@ async fn get_page_collab_data_for_database(
})
.collect();
let row_query_collab_results = collab_access_control_storage
.batch_get_collab(&uid, queries, true)
.batch_get_collab(&uid, &workspace_id.to_string(), queries, true)
.await;
let row_data = tokio::task::spawn_blocking(move || {
let row_collabs: HashMap<String, Vec<u8>> = row_query_collab_results

View file

@ -55,18 +55,33 @@ async fn batch_insert_collab_success_test() {
let mut test_client = TestClient::new_user().await;
let workspace_id = test_client.workspace_id().await;
let mock_encoded_collab_v1 = vec![
test_encode_collab_v1("1", "title", &generate_random_string(1024)),
test_encode_collab_v1("2", "title", &generate_random_string(3 * 1024)),
test_encode_collab_v1("3", "title", &generate_random_string(600 * 1024)),
test_encode_collab_v1("4", "title", &generate_random_string(800 * 1024)),
test_encode_collab_v1("5", "title", &generate_random_string(1024 * 1024)),
];
let mut mock_encoded_collab = vec![];
for _ in 0..200 {
let object_id = Uuid::new_v4().to_string();
let encoded_collab_v1 =
test_encode_collab_v1(&object_id, "title", &generate_random_string(2 * 1024));
mock_encoded_collab.push(encoded_collab_v1);
}
let params_list = (0..5)
.map(|i| CollabParams {
for _ in 0..30 {
let object_id = Uuid::new_v4().to_string();
let encoded_collab_v1 =
test_encode_collab_v1(&object_id, "title", &generate_random_string(10 * 1024));
mock_encoded_collab.push(encoded_collab_v1);
}
for _ in 0..10 {
let object_id = Uuid::new_v4().to_string();
let encoded_collab_v1 =
test_encode_collab_v1(&object_id, "title", &generate_random_string(800 * 1024));
mock_encoded_collab.push(encoded_collab_v1);
}
let params_list = mock_encoded_collab
.iter()
.map(|encoded_collab_v1| CollabParams {
object_id: Uuid::new_v4().to_string(),
encoded_collab_v1: mock_encoded_collab_v1[i].encode_to_bytes().unwrap().into(),
encoded_collab_v1: encoded_collab_v1.encode_to_bytes().unwrap().into(),
collab_type: CollabType::Unknown,
embeddings: None,
})
@ -94,7 +109,9 @@ async fn batch_insert_collab_success_test() {
let encoded_collab = result.0.get(&params.object_id).unwrap();
match encoded_collab {
QueryCollabResult::Success { encode_collab_v1 } => {
assert_eq!(encode_collab_v1, &params.encoded_collab_v1)
let actual = EncodedCollab::decode_from_bytes(encode_collab_v1.as_ref()).unwrap();
let expected = EncodedCollab::decode_from_bytes(params.encoded_collab_v1.as_ref()).unwrap();
assert_eq!(actual.doc_state, expected.doc_state);
},
QueryCollabResult::Failed { error } => {
panic!("Failed to get collab: {:?}", error);
@ -102,7 +119,7 @@ async fn batch_insert_collab_success_test() {
}
}
assert_eq!(result.0.values().len(), 5);
assert_eq!(result.0.values().len(), 240);
}
#[tokio::test]