refactor: Access control (#361)

* refactor: access level

* refactor: access control

* refactor: enforce action

* refactor: collab cache

* chore: fix test

* chore: fix test

* chore: fix test

* chore: fix test

* chore: commit migration file

* chore: commit migration file
Nathan.fooo 2024-03-03 12:55:12 +08:00 committed by GitHub
parent 2cf857bd00
commit 0e57de98d8
48 changed files with 1628 additions and 2398 deletions

Cargo.lock (generated)

@ -325,7 +325,7 @@ dependencies = [
"gotrue",
"gotrue-entity",
"jwt",
"redis 0.23.3",
"redis 0.24.0",
"reqwest",
"serde",
"serde_json",
@ -1253,7 +1253,6 @@ dependencies = [
"gotrue",
"gotrue-entity",
"governor",
"log",
"mime",
"mime_guess",
"parking_lot 0.12.1",
@ -4200,8 +4199,6 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"ryu",
"sha1_smol",
"socket2 0.4.10",
"tokio",
"tokio-native-tls",
"tokio-retry",


@ -17,7 +17,7 @@ askama = "0.12.1"
axum-extra = { version = "0.9.2", features = ["cookie"] }
serde.workspace = true
serde_json.workspace = true
redis = { version = "0.23.3", features = [ "aio", "tokio-comp", "connection-manager"] }
redis = { version = "0.24.0", features = [ "aio", "tokio-comp", "connection-manager"] }
uuid = { version = "1.6.1", features = ["v4"] }
dotenvy = "0.15.7"
reqwest = "0.11.23"


@ -42,7 +42,7 @@ async fn main() {
let redis_client = redis::Client::open(config.redis_url)
.expect("failed to create redis client")
.get_tokio_connection_manager()
.get_connection_manager()
.await
.expect("failed to get redis connection manager");
info!("Redis client initialized.");


@ -3,6 +3,7 @@ pub mod gotrue;
#[cfg(feature = "gotrue_error")]
use crate::gotrue::GoTrueError;
use reqwest::StatusCode;
use serde::Serialize;
use thiserror::Error;
@ -57,8 +58,8 @@ pub enum AppError {
#[error("Not Logged In:{0}")]
NotLoggedIn(String),
#[error("Not Enough Permissions:{0}")]
NotEnoughPermissions(String),
#[error("{user}: do not have permissions to {action}")]
NotEnoughPermissions { user: String, action: String },
#[cfg(feature = "s3_error")]
#[error(transparent)]
@ -110,7 +111,7 @@ pub enum AppError {
impl AppError {
pub fn is_not_enough_permissions(&self) -> bool {
matches!(self, AppError::NotEnoughPermissions(_))
matches!(self, AppError::NotEnoughPermissions { .. })
}
pub fn is_record_not_found(&self) -> bool {
@ -142,7 +143,7 @@ impl AppError {
AppError::InvalidOAuthProvider(_) => ErrorCode::InvalidOAuthProvider,
AppError::InvalidRequest(_) => ErrorCode::InvalidRequest,
AppError::NotLoggedIn(_) => ErrorCode::NotLoggedIn,
AppError::NotEnoughPermissions(_) => ErrorCode::NotEnoughPermissions,
AppError::NotEnoughPermissions { .. } => ErrorCode::NotEnoughPermissions,
#[cfg(feature = "s3_error")]
AppError::S3Error(_) => ErrorCode::S3Error,
AppError::StorageSpaceNotEnough => ErrorCode::StorageSpaceNotEnough,
@ -179,7 +180,11 @@ impl From<reqwest::Error> for AppError {
}
if error.is_request() {
return AppError::InvalidRequest(error.to_string());
return if error.status() == Some(StatusCode::PAYLOAD_TOO_LARGE) {
AppError::PayloadTooLarge(error.to_string())
} else {
AppError::InvalidRequest(error.to_string())
};
}
AppError::Unhandled(error.to_string())
}
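For orientation, here is a minimal, self-contained sketch of the error changes above; the enum below is a local stand-in rather than the crate's full AppError, and the status check mirrors the new From<reqwest::Error> branch.

use reqwest::StatusCode;
use thiserror::Error;

// Stand-in covering only the variants touched by this hunk (sketch, not the real AppError).
#[derive(Debug, Error)]
enum AppErrorSketch {
  #[error("{user}: do not have permissions to {action}")]
  NotEnoughPermissions { user: String, action: String },
  #[error("Payload Too Large: {0}")]
  PayloadTooLarge(String),
  #[error("Invalid Request: {0}")]
  InvalidRequest(String),
}

// Request errors carrying a 413 status become PayloadTooLarge; everything else stays InvalidRequest.
fn classify_request_error(status: Option<StatusCode>, msg: String) -> AppErrorSketch {
  if status == Some(StatusCode::PAYLOAD_TOO_LARGE) {
    AppErrorSketch::PayloadTooLarge(msg)
  } else {
    AppErrorSketch::InvalidRequest(msg)
  }
}

// Constructing the reworked struct variant at a hypothetical call site.
fn deny(user: &str, action: &str) -> AppErrorSketch {
  AppErrorSketch::NotEnoughPermissions {
    user: user.to_string(),
    action: action.to_string(),
  }
}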


@ -44,7 +44,6 @@ database-entity.workspace = true
app-error = { workspace = true, features = ["tokio_error", "bincode_error"] }
scraper = { version = "0.17.1", optional = true }
governor = { version = "0.6.0" }
log = "0.4.20"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
tokio-retry = "0.3"
@ -67,7 +66,3 @@ again = "0.1.2"
collab-sync = ["collab", "yrs"]
test_util = ["scraper"]
template = ["workspace-template"]
[profile.dev]
debug = true


@ -11,7 +11,6 @@ use collab::core::origin::CollabOrigin;
use collab::preclude::Collab;
use futures_util::{SinkExt, StreamExt};
use log::trace;
use realtime_entity::collab_msg::{
AckCode, ClientCollabMessage, InitSync, ServerCollabMessage, ServerInit, UpdateSync,
};
@ -24,7 +23,7 @@ use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};
use tokio::sync::{watch, Mutex};
use tokio_stream::wrappers::WatchStream;
use tracing::{error, info, warn};
use tracing::{error, info, trace, warn};
use yrs::encoding::read::Cursor;
use yrs::updates::decoder::DecoderV1;
use yrs::updates::encoder::{Encoder, EncoderV1};


@ -382,6 +382,22 @@ impl AFAccessLevel {
}
}
impl From<&AFRole> for AFAccessLevel {
fn from(value: &AFRole) -> Self {
match value {
AFRole::Owner => AFAccessLevel::FullAccess,
AFRole::Member => AFAccessLevel::ReadAndWrite,
AFRole::Guest => AFAccessLevel::ReadOnly,
}
}
}
impl From<AFRole> for AFAccessLevel {
fn from(value: AFRole) -> Self {
AFAccessLevel::from(&value)
}
}
impl From<i32> for AFAccessLevel {
fn from(value: i32) -> Self {
// Can't modify the value of the enum


@ -20,19 +20,6 @@ use std::{ops::DerefMut, str::FromStr};
use tracing::{error, event, instrument};
use uuid::Uuid;
#[inline]
pub async fn collab_exists(pg_pool: &PgPool, oid: &str) -> Result<bool, sqlx::Error> {
let result = sqlx::query_scalar!(
r#"
SELECT EXISTS (SELECT 1 FROM af_collab WHERE oid = $1 LIMIT 1)
"#,
&oid,
)
.fetch_one(pg_pool)
.await;
transform_record_not_found_error(result)
}
/// Inserts a new row into the `af_collab` table or updates an existing row if it matches the
/// provided `object_id`.Additionally, if the row is being inserted for the first time, a corresponding
/// entry will be added to the `af_collab_member` table.
@ -472,11 +459,15 @@ pub async fn insert_collab_member(
Ok(())
}
pub async fn delete_collab_member(uid: i64, oid: &str, pg_pool: &PgPool) -> Result<(), AppError> {
pub async fn delete_collab_member(
uid: i64,
oid: &str,
txn: &mut Transaction<'_, sqlx::Postgres>,
) -> Result<(), AppError> {
sqlx::query("DELETE FROM af_collab_member WHERE uid = $1 AND oid = $2")
.bind(uid)
.bind(oid)
.execute(pg_pool)
.execute(txn.deref_mut())
.await?;
Ok(())
}
@ -504,10 +495,16 @@ pub async fn select_collab_members(
) -> Result<Vec<AFCollabMember>, AppError> {
let members = sqlx::query(
r#"
SELECT af_collab_member.uid, af_collab_member.oid, af_permissions.id, af_permissions.name, af_permissions.access_level, af_permissions.description
SELECT af_collab_member.uid,
af_collab_member.oid,
af_permissions.id,
af_permissions.name,
af_permissions.access_level,
af_permissions.description
FROM af_collab_member
JOIN af_permissions ON af_collab_member.permission_id = af_permissions.id
WHERE af_collab_member.oid = $1
ORDER BY af_collab_member.created_at ASC
"#,
)
.bind(oid)


@ -1,22 +1,15 @@
use crate::collab::{collab_db_ops, is_collab_exists};
use anyhow::anyhow;
use app_error::AppError;
use async_trait::async_trait;
use collab::core::collab_plugin::EncodedCollab;
use database_entity::dto::{
AFAccessLevel, AFRole, AFSnapshotMeta, AFSnapshotMetas, CollabParams, CreateCollabParams,
AFAccessLevel, AFSnapshotMeta, AFSnapshotMetas, CollabParams, CreateCollabParams,
InsertSnapshotParams, QueryCollab, QueryCollabParams, QueryCollabResult, SnapshotData,
};
use sqlx::{Executor, PgPool, Postgres, Transaction};
use sqlx::Transaction;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, event, warn, Level};
use validator::Validate;
pub const COLLAB_SNAPSHOT_LIMIT: i64 = 30;
pub const SNAPSHOT_PER_HOUR: i64 = 6;
@ -26,29 +19,18 @@ pub type DatabaseResult<T, E = AppError> = core::result::Result<T, E>;
/// of the Collab object.
#[async_trait]
pub trait CollabStorageAccessControl: Send + Sync + 'static {
/// Checks if the user with the given ID can access the [Collab] with the given ID.
async fn get_or_refresh_collab_access_level<'a, E: Executor<'a, Database = Postgres>>(
&self,
uid: &i64,
oid: &str,
executor: E,
) -> Result<AFAccessLevel, AppError>;
/// Updates the cache of the access level of the user for given collab object.
async fn cache_collab_access_level(
&self,
uid: &i64,
oid: &str,
level: AFAccessLevel,
) -> Result<(), AppError>;
async fn update_policy(&self, uid: &i64, oid: &str, level: AFAccessLevel)
-> Result<(), AppError>;
async fn enforce_read_collab(&self, uid: &i64, oid: &str) -> Result<bool, AppError>;
async fn enforce_write_collab(&self, uid: &i64, oid: &str) -> Result<bool, AppError>;
async fn enforce_delete(&self, uid: &i64, oid: &str) -> Result<bool, AppError>;
/// Returns the role of the user in the workspace.
async fn get_user_workspace_role<'a, E: Executor<'a, Database = Postgres>>(
&self,
uid: &i64,
workspace_id: &str,
executor: E,
) -> Result<AFRole, AppError>;
async fn enforce_write_workspace(&self, uid: &i64, workspace_id: &str) -> Result<bool, AppError>;
}
/// Represents a storage mechanism for collaborations.
@ -57,11 +39,14 @@ pub trait CollabStorageAccessControl: Send + Sync + 'static {
/// Implementors of this trait should provide the actual storage logic, be it in-memory, file-based, database-backed, etc.
#[async_trait]
pub trait CollabStorage: Send + Sync + 'static {
fn config(&self) -> &WriteConfig;
fn encode_collab_mem_hit_rate(&self) -> f64;
async fn upsert_collab(&self, uid: &i64, params: CreateCollabParams) -> DatabaseResult<()>;
async fn insert_collab(
&self,
uid: &i64,
params: CreateCollabParams,
is_new: bool,
) -> DatabaseResult<()>;
/// Insert/update a new collaboration in the storage.
///
@ -72,7 +57,7 @@ pub trait CollabStorage: Send + Sync + 'static {
/// # Returns
///
/// * `Result<()>` - Returns `Ok(())` if the collaboration was created successfully, `Err` otherwise.
async fn upsert_collab_with_transaction(
async fn insert_or_update_collab(
&self,
workspace_id: &str,
uid: &i64,
@ -133,19 +118,20 @@ impl<T> CollabStorage for Arc<T>
where
T: CollabStorage,
{
fn config(&self) -> &WriteConfig {
self.as_ref().config()
}
fn encode_collab_mem_hit_rate(&self) -> f64 {
self.as_ref().encode_collab_mem_hit_rate()
}
async fn upsert_collab(&self, uid: &i64, params: CreateCollabParams) -> DatabaseResult<()> {
self.as_ref().upsert_collab(uid, params).await
async fn insert_collab(
&self,
uid: &i64,
params: CreateCollabParams,
is_new: bool,
) -> DatabaseResult<()> {
self.as_ref().insert_collab(uid, params, is_new).await
}
async fn upsert_collab_with_transaction(
async fn insert_or_update_collab(
&self,
workspace_id: &str,
uid: &i64,
@ -154,7 +140,7 @@ where
) -> DatabaseResult<()> {
self
.as_ref()
.upsert_collab_with_transaction(workspace_id, uid, params, transaction)
.insert_or_update_collab(workspace_id, uid, params, transaction)
.await
}
@ -210,175 +196,162 @@ where
self.as_ref().get_collab_snapshot_list(oid).await
}
}
#[derive(Debug, Clone)]
pub struct WriteConfig {
pub flush_per_update: u32,
}
impl Default for WriteConfig {
fn default() -> Self {
Self {
flush_per_update: 100,
}
}
}
#[derive(Clone)]
pub struct CollabStoragePgImpl {
pub pg_pool: PgPool,
config: WriteConfig,
}
impl CollabStoragePgImpl {
pub fn new(pg_pool: PgPool) -> Self {
let config = WriteConfig::default();
Self { pg_pool, config }
}
pub fn config(&self) -> &WriteConfig {
&self.config
}
pub async fn is_exist(&self, object_id: &str) -> bool {
collab_db_ops::collab_exists(&self.pg_pool, object_id)
.await
.unwrap_or(false)
}
pub async fn is_collab_exist(&self, oid: &str) -> DatabaseResult<bool> {
let is_exist = is_collab_exists(oid, &self.pg_pool).await?;
Ok(is_exist)
}
pub async fn upsert_collab_with_transaction(
&self,
workspace_id: &str,
uid: &i64,
params: CollabParams,
transaction: &mut Transaction<'_, sqlx::Postgres>,
) -> DatabaseResult<()> {
collab_db_ops::insert_into_af_collab(transaction, uid, workspace_id, &params).await?;
Ok(())
}
pub async fn get_collab_encoded(
&self,
_uid: &i64,
params: QueryCollabParams,
) -> Result<EncodedCollab, AppError> {
event!(
Level::INFO,
"Get encoded collab:{} from disk",
params.object_id
);
const MAX_ATTEMPTS: usize = 3;
let mut attempts = 0;
loop {
let result = collab_db_ops::select_blob_from_af_collab(
&self.pg_pool,
&params.collab_type,
&params.object_id,
)
.await;
match result {
Ok(data) => {
return tokio::task::spawn_blocking(move || {
EncodedCollab::decode_from_bytes(&data).map_err(|err| {
AppError::Internal(anyhow!("fail to decode data to EncodedCollab: {:?}", err))
})
})
.await?;
},
Err(e) => {
// Handle non-retryable errors immediately
if matches!(e, sqlx::Error::RowNotFound) {
let msg = format!("Can't find the row for query: {:?}", params);
return Err(AppError::RecordNotFound(msg));
}
// Increment attempts and retry if below MAX_ATTEMPTS and the error is retryable
if attempts < MAX_ATTEMPTS - 1 && matches!(e, sqlx::Error::PoolTimedOut) {
attempts += 1;
sleep(Duration::from_millis(500 * attempts as u64)).await;
continue;
} else {
return Err(e.into());
}
},
}
}
}
pub async fn batch_get_collab(
&self,
_uid: &i64,
queries: Vec<QueryCollab>,
) -> HashMap<String, QueryCollabResult> {
collab_db_ops::batch_select_collab_blob(&self.pg_pool, queries).await
}
pub async fn delete_collab(&self, _uid: &i64, object_id: &str) -> DatabaseResult<()> {
collab_db_ops::delete_collab(&self.pg_pool, object_id).await?;
Ok(())
}
pub async fn should_create_snapshot(&self, oid: &str) -> bool {
if oid.is_empty() {
warn!("unexpected empty object id when checking should_create_snapshot");
return false;
}
collab_db_ops::should_create_snapshot(oid, &self.pg_pool)
.await
.unwrap_or(false)
}
pub async fn create_snapshot(
&self,
params: InsertSnapshotParams,
) -> DatabaseResult<AFSnapshotMeta> {
params.validate()?;
debug!("create snapshot for object:{}", params.object_id);
match self.pg_pool.try_begin().await {
Ok(Some(transaction)) => {
let meta = collab_db_ops::create_snapshot_and_maintain_limit(
transaction,
&params.workspace_id,
&params.object_id,
&params.encoded_collab_v1,
COLLAB_SNAPSHOT_LIMIT,
)
.await?;
Ok(meta)
},
_ => Err(AppError::Internal(anyhow!(
"fail to acquire transaction to create snapshot for object:{}",
params.object_id,
))),
}
}
pub async fn get_collab_snapshot(&self, snapshot_id: &i64) -> DatabaseResult<SnapshotData> {
match collab_db_ops::select_snapshot(&self.pg_pool, snapshot_id).await? {
None => Err(AppError::RecordNotFound(format!(
"Can't find the snapshot with id:{}",
snapshot_id
))),
Some(row) => Ok(SnapshotData {
object_id: row.oid,
encoded_collab_v1: row.blob,
workspace_id: row.workspace_id.to_string(),
}),
}
}
/// Returns list of snapshots for given object_id in descending order of creation time.
pub async fn get_collab_snapshot_list(&self, oid: &str) -> DatabaseResult<AFSnapshotMetas> {
let metas = collab_db_ops::get_all_collab_snapshot_meta(&self.pg_pool, oid).await?;
Ok(metas)
}
}
//
// #[derive(Clone)]
// pub struct CollabDiskCache {
// pub pg_pool: PgPool,
// config: WriteConfig,
// }
//
// impl CollabDiskCache {
// pub fn new(pg_pool: PgPool) -> Self {
// let config = WriteConfig::default();
// Self { pg_pool, config }
// }
// pub fn config(&self) -> &WriteConfig {
// &self.config
// }
//
// pub async fn is_exist(&self, object_id: &str) -> bool {
// collab_db_ops::collab_exists(&self.pg_pool, object_id)
// .await
// .unwrap_or(false)
// }
//
// pub async fn is_collab_exist(&self, oid: &str) -> DatabaseResult<bool> {
// let is_exist = is_collab_exists(oid, &self.pg_pool).await?;
// Ok(is_exist)
// }
//
// pub async fn upsert_collab_with_transaction(
// &self,
// workspace_id: &str,
// uid: &i64,
// params: CollabParams,
// transaction: &mut Transaction<'_, sqlx::Postgres>,
// ) -> DatabaseResult<()> {
// collab_db_ops::insert_into_af_collab(transaction, uid, workspace_id, &params).await?;
// Ok(())
// }
//
// pub async fn get_collab_encoded(
// &self,
// _uid: &i64,
// params: QueryCollabParams,
// ) -> Result<EncodedCollab, AppError> {
// event!(
// Level::INFO,
// "Get encoded collab:{} from disk",
// params.object_id
// );
//
// const MAX_ATTEMPTS: usize = 3;
// let mut attempts = 0;
//
// loop {
// let result = collab_db_ops::select_blob_from_af_collab(
// &self.pg_pool,
// &params.collab_type,
// &params.object_id,
// )
// .await;
//
// match result {
// Ok(data) => {
// return tokio::task::spawn_blocking(move || {
// EncodedCollab::decode_from_bytes(&data).map_err(|err| {
// AppError::Internal(anyhow!("fail to decode data to EncodedCollab: {:?}", err))
// })
// })
// .await?;
// },
// Err(e) => {
// // Handle non-retryable errors immediately
// if matches!(e, sqlx::Error::RowNotFound) {
// let msg = format!("Can't find the row for query: {:?}", params);
// return Err(AppError::RecordNotFound(msg));
// }
//
// // Increment attempts and retry if below MAX_ATTEMPTS and the error is retryable
// if attempts < MAX_ATTEMPTS - 1 && matches!(e, sqlx::Error::PoolTimedOut) {
// attempts += 1;
// sleep(Duration::from_millis(500 * attempts as u64)).await;
// continue;
// } else {
// return Err(e.into());
// }
// },
// }
// }
// }
//
// pub async fn batch_get_collab(
// &self,
// _uid: &i64,
// queries: Vec<QueryCollab>,
// ) -> HashMap<String, QueryCollabResult> {
// collab_db_ops::batch_select_collab_blob(&self.pg_pool, queries).await
// }
//
// pub async fn delete_collab(&self, _uid: &i64, object_id: &str) -> DatabaseResult<()> {
// collab_db_ops::delete_collab(&self.pg_pool, object_id).await?;
// Ok(())
// }
//
// pub async fn should_create_snapshot(&self, oid: &str) -> bool {
// if oid.is_empty() {
// warn!("unexpected empty object id when checking should_create_snapshot");
// return false;
// }
//
// collab_db_ops::should_create_snapshot(oid, &self.pg_pool)
// .await
// .unwrap_or(false)
// }
//
// pub async fn create_snapshot(
// &self,
// params: InsertSnapshotParams,
// ) -> DatabaseResult<AFSnapshotMeta> {
// params.validate()?;
//
// debug!("create snapshot for object:{}", params.object_id);
// match self.pg_pool.try_begin().await {
// Ok(Some(transaction)) => {
// let meta = collab_db_ops::create_snapshot_and_maintain_limit(
// transaction,
// &params.workspace_id,
// &params.object_id,
// &params.encoded_collab_v1,
// COLLAB_SNAPSHOT_LIMIT,
// )
// .await?;
// Ok(meta)
// },
// _ => Err(AppError::Internal(anyhow!(
// "fail to acquire transaction to create snapshot for object:{}",
// params.object_id,
// ))),
// }
// }
//
// pub async fn get_collab_snapshot(&self, snapshot_id: &i64) -> DatabaseResult<SnapshotData> {
// match collab_db_ops::select_snapshot(&self.pg_pool, snapshot_id).await? {
// None => Err(AppError::RecordNotFound(format!(
// "Can't find the snapshot with id:{}",
// snapshot_id
// ))),
// Some(row) => Ok(SnapshotData {
// object_id: row.oid,
// encoded_collab_v1: row.blob,
// workspace_id: row.workspace_id.to_string(),
// }),
// }
// }
//
// /// Returns list of snapshots for given object_id in descending order of creation time.
// pub async fn get_collab_snapshot_list(&self, oid: &str) -> DatabaseResult<AFSnapshotMetas> {
// let metas = collab_db_ops::get_all_collab_snapshot_meta(&self.pg_pool, oid).await?;
// Ok(metas)
// }
// }
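To make the slimmed-down CollabStorageAccessControl trait near the top of this file's diff concrete, here is a hedged sketch of a caller-side guard; the trait is reduced to the single method used, the error type is a plain String, and the guard function itself is hypothetical.

use async_trait::async_trait;

// Reduced stand-in for the trait above: only enforce_write_collab is kept (sketch only).
#[async_trait]
trait WriteEnforcer: Send + Sync + 'static {
  async fn enforce_write_collab(&self, uid: &i64, oid: &str) -> Result<bool, String>;
}

// Hypothetical guard a storage wrapper could run before insert_or_update_collab.
async fn ensure_can_write<E: WriteEnforcer>(enforcer: &E, uid: &i64, oid: &str) -> Result<(), String> {
  if enforcer.enforce_write_collab(uid, oid).await? {
    Ok(())
  } else {
    Err(format!("{}: do not have permissions to write collab {}", uid, oid))
  }
}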


@ -204,7 +204,7 @@ pub async fn select_user_can_edit_collab(
}
#[inline]
pub async fn insert_workspace_member_with_txn(
pub async fn upsert_workspace_member_with_txn(
txn: &mut Transaction<'_, sqlx::Postgres>,
workspace_id: &uuid::Uuid,
member_email: &str,
@ -237,12 +237,8 @@ pub async fn upsert_workspace_member(
pool: &PgPool,
workspace_id: &Uuid,
email: &str,
role: Option<AFRole>,
role: AFRole,
) -> Result<(), sqlx::Error> {
if role.is_none() {
return Ok(());
}
event!(
tracing::Level::TRACE,
"update workspace member: workspace_id:{}, uid {:?}, role:{:?}",
@ -251,7 +247,7 @@ pub async fn upsert_workspace_member(
role
);
let role_id: i32 = role.unwrap().into();
let role_id: i32 = role.into();
sqlx::query!(
r#"
UPDATE af_workspace_member
@ -273,7 +269,6 @@ pub async fn upsert_workspace_member(
#[inline]
pub async fn delete_workspace_members(
_user_uuid: &Uuid,
txn: &mut Transaction<'_, sqlx::Postgres>,
workspace_id: &Uuid,
member_email: &str,
@ -298,9 +293,10 @@ pub async fn delete_workspace_members(
.unwrap_or(false);
if is_owner {
return Err(AppError::NotEnoughPermissions(
"Owner cannot be deleted".to_string(),
));
return Err(AppError::NotEnoughPermissions {
user: member_email.to_string(),
action: format!("delete member from workspace {}", workspace_id),
});
}
sqlx::query!(


@ -1,4 +1,4 @@
use crate::collaborate::{CollabAccessControl, RealtimeServer};
use crate::collaborate::{RealtimeAccessControl, RealtimeServer};
use crate::entities::{ClientMessage, Connect, Disconnect, RealtimeMessage, RealtimeUser};
use crate::error::RealtimeError;
use actix::{
@ -24,7 +24,7 @@ const RATE_LIMIT_INTERVAL: Duration = Duration::from_secs(1);
pub struct RealtimeClient<
U: Unpin + RealtimeUser,
S: Unpin + 'static,
AC: Unpin + CollabAccessControl,
AC: Unpin + RealtimeAccessControl,
> {
session_id: String,
user: U,
@ -41,7 +41,7 @@ impl<U, S, AC> RealtimeClient<U, S, AC>
where
U: Unpin + RealtimeUser + Clone,
S: CollabStorage + Unpin,
AC: CollabAccessControl + Unpin,
AC: RealtimeAccessControl + Unpin,
{
pub fn new(
user: U,
@ -138,7 +138,7 @@ impl<U, S, P> Actor for RealtimeClient<U, S, P>
where
U: Unpin + RealtimeUser,
S: Unpin + CollabStorage,
P: CollabAccessControl + Unpin,
P: RealtimeAccessControl + Unpin,
{
type Context = ws::WebsocketContext<Self>;
@ -220,7 +220,7 @@ impl<U, S, AC> Handler<RealtimeMessage> for RealtimeClient<U, S, AC>
where
U: Unpin + RealtimeUser,
S: Unpin + CollabStorage,
AC: CollabAccessControl + Unpin,
AC: RealtimeAccessControl + Unpin,
{
type Result = ();
@ -235,7 +235,7 @@ impl<U, S, AC> StreamHandler<Result<ws::Message, ws::ProtocolError>> for Realtim
where
U: Unpin + RealtimeUser + Clone,
S: Unpin + CollabStorage,
AC: CollabAccessControl + Unpin,
AC: RealtimeAccessControl + Unpin,
{
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
let now = Instant::now();


@ -1,5 +1,5 @@
use crate::collaborate::group::CollabGroup;
use crate::collaborate::{CollabAccessControl, CollabStoragePlugin};
use crate::collaborate::{CollabStoragePlugin, RealtimeAccessControl};
use crate::entities::RealtimeUser;
use anyhow::Error;
use collab::core::origin::CollabOrigin;
@ -20,7 +20,7 @@ impl<S, U, AC> AllCollabGroup<S, U, AC>
where
S: CollabStorage,
U: RealtimeUser,
AC: CollabAccessControl,
AC: RealtimeAccessControl,
{
pub fn new(storage: Arc<S>, access_control: Arc<AC>) -> Self {
Self {


@ -1,7 +1,7 @@
use crate::collaborate::all_group::AllCollabGroup;
use crate::collaborate::group_sub::{CollabUserMessage, SubscribeGroup};
use crate::collaborate::{
broadcast_client_collab_message, CollabAccessControl, CollabClientStream,
broadcast_client_collab_message, CollabClientStream, RealtimeAccessControl,
};
use crate::entities::{Editing, RealtimeUser};
use crate::error::RealtimeError;
@ -42,7 +42,7 @@ impl<S, U, AC> GroupCommandRunner<S, U, AC>
where
S: CollabStorage,
U: RealtimeUser,
AC: CollabAccessControl,
AC: RealtimeAccessControl,
{
pub async fn run(mut self, object_id: String) {
let mut receiver = self.recv.take().expect("Only take once");


@ -1,5 +1,5 @@
use crate::collaborate::all_group::AllCollabGroup;
use crate::collaborate::{CollabAccessControl, CollabClientStream};
use crate::collaborate::{CollabClientStream, RealtimeAccessControl};
use crate::entities::{Editing, RealtimeUser};
use crate::error::StreamError;
use crate::util::channel_ext::UnboundedSenderSink;
@ -30,7 +30,7 @@ impl<'a, S, U, AC> SubscribeGroup<'a, S, U, AC>
where
U: RealtimeUser,
S: CollabStorage,
AC: CollabAccessControl,
AC: RealtimeAccessControl,
{
fn get_origin(collab_message: &ClientCollabMessage) -> &CollabOrigin {
collab_message.origin()
@ -40,8 +40,7 @@ where
object_id: &'b str,
client_stream: &'b mut CollabClientStream,
client_uid: i64,
sink_permission_service: Arc<AC>,
stream_permission_service: Arc<AC>,
access_control: Arc<AC>,
) -> (
UnboundedSenderSink<CollabMessage>,
ReceiverStream<Result<ClientCollabMessage, StreamError>>,
@ -49,6 +48,8 @@ where
where
'a: 'b,
{
let sink_access_control = access_control.clone();
let stream_access_control = access_control.clone();
let (sink, stream) = client_stream.client_channel::<CollabMessage, _, _>(
object_id,
move |object_id, msg| {
@ -62,9 +63,9 @@ where
}
let object_id = object_id.to_string();
let permission_service = sink_permission_service.clone();
let clone_sink_access_control = sink_access_control.clone();
Box::pin(async move {
match permission_service
match clone_sink_access_control
.can_receive_collab_update(&client_uid, &object_id)
.await
{
@ -96,22 +97,22 @@ where
let is_init = msg.is_init_msg();
let object_id = object_id.to_string();
let cloned_stream_permission_service = stream_permission_service.clone();
let cloned_stream_access_control = stream_access_control.clone();
Box::pin(async move {
// If the message is init sync, and it's allow the send to the group.
// If the message is init sync, and it's allow to send to the group.
if is_init {
return true;
}
match cloned_stream_permission_service
match cloned_stream_access_control
.can_send_collab_update(&client_uid, &object_id)
.await
{
Ok(is_allowed) => {
if !is_allowed {
trace!(
"client:{} is not allowed to send {} updates",
"client:{} is not allowed to edit {} updates",
client_uid,
object_id,
);
@ -138,7 +139,7 @@ impl<'a, S, U, AC> SubscribeGroup<'a, S, U, AC>
where
U: RealtimeUser,
S: CollabStorage,
AC: CollabAccessControl,
AC: RealtimeAccessControl,
{
pub(crate) async fn run(self) {
let CollabUserMessage {
@ -172,7 +173,6 @@ where
client_stream.value_mut(),
client_uid,
self.access_control.clone(),
self.access_control.clone(),
);
collab_group
.subscribe(user, origin.clone(), sink, stream)


@ -1,10 +1,5 @@
use app_error::AppError;
use async_trait::async_trait;
use database_entity::dto::AFAccessLevel;
use reqwest::Method;
use std::sync::Arc;
use tracing::instrument;
#[derive(Debug)]
pub enum CollabUserId<'a> {
@ -25,27 +20,7 @@ impl<'a> From<&'a uuid::Uuid> for CollabUserId<'a> {
}
#[async_trait]
pub trait CollabAccessControl: Sync + Send + 'static {
/// Return the access level of the user in the collab
async fn get_collab_access_level(&self, uid: &i64, oid: &str) -> Result<AFAccessLevel, AppError>;
async fn insert_collab_access_level(
&self,
uid: &i64,
oid: &str,
level: AFAccessLevel,
) -> Result<(), AppError>;
/// Return true if the user from the HTTP request is allowed to access the collab object.
/// This function will be called very frequently, so it should be very fast.
///
async fn can_access_http_method(
&self,
uid: &i64,
oid: &str,
method: &Method,
) -> Result<bool, AppError>;
pub trait RealtimeAccessControl: Sync + Send + 'static {
/// Return true if the user is allowed to send the message.
/// This function will be called very frequently, so it should be very fast.
///
@ -61,43 +36,3 @@ pub trait CollabAccessControl: Sync + Send + 'static {
/// The user can recv the message if the user is the member of the collab object
async fn can_receive_collab_update(&self, uid: &i64, oid: &str) -> Result<bool, AppError>;
}
//
#[async_trait]
impl<T> CollabAccessControl for Arc<T>
where
T: CollabAccessControl,
{
#[instrument(level = "debug", skip_all)]
async fn get_collab_access_level(&self, uid: &i64, oid: &str) -> Result<AFAccessLevel, AppError> {
self.as_ref().get_collab_access_level(uid, oid).await
}
async fn insert_collab_access_level(
&self,
uid: &i64,
oid: &str,
level: AFAccessLevel,
) -> Result<(), AppError> {
self
.as_ref()
.insert_collab_access_level(uid, oid, level)
.await
}
async fn can_access_http_method(
&self,
uid: &i64,
oid: &str,
method: &Method,
) -> Result<bool, AppError> {
self.as_ref().can_access_http_method(uid, oid, method).await
}
async fn can_send_collab_update(&self, uid: &i64, oid: &str) -> Result<bool, AppError> {
self.as_ref().can_send_collab_update(uid, oid).await
}
async fn can_receive_collab_update(&self, uid: &i64, oid: &str) -> Result<bool, AppError> {
self.as_ref().can_receive_collab_update(uid, oid).await
}
}
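As a rough illustration of the renamed RealtimeAccessControl trait above, here is a hedged stand-in implementation that allows every update; the trait body is copied from the hunk with the error type simplified, and AllowAll is invented for the sketch.

use async_trait::async_trait;

// Simplified copy of the trait above; the error type is reduced to String for the sketch.
#[async_trait]
trait RealtimeAccessControl: Sync + Send + 'static {
  async fn can_send_collab_update(&self, uid: &i64, oid: &str) -> Result<bool, String>;
  async fn can_receive_collab_update(&self, uid: &i64, oid: &str) -> Result<bool, String>;
}

// Hypothetical permissive implementation: every user may send and receive collab updates.
struct AllowAll;

#[async_trait]
impl RealtimeAccessControl for AllowAll {
  async fn can_send_collab_update(&self, _uid: &i64, _oid: &str) -> Result<bool, String> {
    Ok(true)
  }
  async fn can_receive_collab_update(&self, _uid: &i64, _oid: &str) -> Result<bool, String> {
    Ok(true)
  }
}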


@ -3,7 +3,7 @@ use app_error::AppError;
use async_trait::async_trait;
use std::fmt::Display;
use crate::collaborate::CollabAccessControl;
use crate::collaborate::RealtimeAccessControl;
use anyhow::anyhow;
use collab::core::collab::TransactionMutExt;
@ -15,9 +15,7 @@ use collab_document::document::check_document_is_valid;
use collab_entity::CollabType;
use collab_folder::check_folder_is_valid;
use database::collab::CollabStorage;
use database_entity::dto::{
AFAccessLevel, CreateCollabParams, InsertSnapshotParams, QueryCollabParams,
};
use database_entity::dto::{CreateCollabParams, InsertSnapshotParams, QueryCollabParams};
use md5::Digest;
use parking_lot::Mutex;
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU32, Ordering};
@ -34,6 +32,7 @@ pub struct CollabStoragePlugin<S, AC> {
storage: Arc<S>,
edit_state: Arc<CollabEditState>,
collab_type: CollabType,
#[allow(dead_code)]
access_control: Arc<AC>,
latest_collab_md5: Mutex<Option<Digest>>,
}
@ -41,7 +40,7 @@ pub struct CollabStoragePlugin<S, AC> {
impl<S, AC> CollabStoragePlugin<S, AC>
where
S: CollabStorage,
AC: CollabAccessControl,
AC: RealtimeAccessControl,
{
pub fn new(
uid: i64,
@ -68,11 +67,6 @@ where
async fn insert_new_collab(&self, doc: &Doc, object_id: &str) -> Result<(), AppError> {
match doc.get_encoded_collab_v1().encode_to_bytes() {
Ok(encoded_collab_v1) => {
let _ = self
.access_control
.insert_collab_access_level(&self.uid, object_id, AFAccessLevel::FullAccess)
.await;
let params = CreateCollabParams {
object_id: object_id.to_string(),
encoded_collab_v1,
@ -83,7 +77,7 @@ where
self
.storage
.upsert_collab(&self.uid, params)
.insert_collab(&self.uid, params, true)
.await
.map_err(|err| {
error!("fail to create new collab in plugin: {:?}", err);
@ -128,7 +122,7 @@ async fn init_collab(
impl<S, AC> CollabPlugin for CollabStoragePlugin<S, AC>
where
S: CollabStorage,
AC: CollabAccessControl,
AC: RealtimeAccessControl,
{
async fn init(&self, object_id: &str, _origin: &CollabOrigin, doc: &Doc) {
let params = QueryCollabParams::new(object_id, self.collab_type.clone(), &self.workspace_id);
@ -215,10 +209,7 @@ where
}
trace!("{} edit state:{}", object_id, self.edit_state);
if self
.edit_state
.should_flush(self.storage.config().flush_per_update, 3 * 60)
{
if self.edit_state.should_flush(100, 3 * 60) {
self.edit_state.tick();
let _object_id = object_id.to_string();
// let weak_group = self.group.clone();
@ -271,7 +262,7 @@ where
let uid = self.uid;
tokio::spawn(async move {
info!("[realtime] flush collab: {}", params.object_id);
match storage.upsert_collab(&uid, params).await {
match storage.insert_collab(&uid, params, false).await {
Ok(_) => {},
Err(err) => error!("Failed to save collab: {:?}", err),
}


@ -1,7 +1,7 @@
use crate::client::ClientWSSink;
use crate::collaborate::all_group::AllCollabGroup;
use crate::collaborate::group_cmd::{GroupCommand, GroupCommandRunner, GroupCommandSender};
use crate::collaborate::permission::CollabAccessControl;
use crate::collaborate::permission::RealtimeAccessControl;
use crate::collaborate::RealtimeMetrics;
use crate::entities::{
ClientMessage, ClientStreamMessage, Connect, Disconnect, Editing, RealtimeMessage, RealtimeUser,
@ -58,7 +58,7 @@ impl<S, U, AC> RealtimeServer<S, U, AC>
where
S: CollabStorage,
U: RealtimeUser,
AC: CollabAccessControl,
AC: RealtimeAccessControl,
{
pub fn new(
storage: Arc<S>,
@ -200,7 +200,7 @@ async fn remove_user<S, U, AC>(
) where
S: CollabStorage,
U: RealtimeUser,
AC: CollabAccessControl,
AC: RealtimeAccessControl,
{
let entry = editing_collab_by_user.remove(user);
if let Some(entry) = entry {
@ -214,7 +214,7 @@ impl<S, U, AC> Actor for RealtimeServer<S, U, AC>
where
S: 'static + Unpin,
U: RealtimeUser + Unpin,
AC: CollabAccessControl + Unpin,
AC: RealtimeAccessControl + Unpin,
{
type Context = Context<Self>;
@ -227,7 +227,7 @@ impl<S, U, AC> Handler<Connect<U>> for RealtimeServer<S, U, AC>
where
U: RealtimeUser + Unpin,
S: CollabStorage + Unpin,
AC: CollabAccessControl + Unpin,
AC: RealtimeAccessControl + Unpin,
{
type Result = ResponseFuture<Result<(), RealtimeError>>;
@ -266,7 +266,7 @@ impl<S, U, AC> Handler<Disconnect<U>> for RealtimeServer<S, U, AC>
where
U: RealtimeUser + Unpin,
S: CollabStorage + Unpin,
AC: CollabAccessControl + Unpin,
AC: RealtimeAccessControl + Unpin,
{
type Result = ResponseFuture<Result<(), RealtimeError>>;
/// Handles the disconnection of a user from the collaboration server.
@ -308,7 +308,7 @@ impl<S, U, AC> Handler<ClientMessage<U>> for RealtimeServer<S, U, AC>
where
U: RealtimeUser + Unpin,
S: CollabStorage + Unpin,
AC: CollabAccessControl + Unpin,
AC: RealtimeAccessControl + Unpin,
{
type Result = ResponseFuture<Result<(), RealtimeError>>;
@ -354,7 +354,7 @@ impl<S, U, AC> Handler<ClientStreamMessage> for RealtimeServer<S, U, AC>
where
U: RealtimeUser + Unpin,
S: CollabStorage + Unpin,
AC: CollabAccessControl + Unpin,
AC: RealtimeAccessControl + Unpin,
{
type Result = ResponseFuture<Result<(), RealtimeError>>;
@ -442,7 +442,7 @@ async fn remove_user_from_group<S, U, AC>(
) where
S: CollabStorage,
U: RealtimeUser,
AC: CollabAccessControl,
AC: RealtimeAccessControl,
{
let _ = groups.remove_user(&editing.object_id, user).await;
if let Some(group) = groups.get_group(&editing.object_id).await {
@ -460,7 +460,7 @@ impl<S, U, AC> actix::Supervised for RealtimeServer<S, U, AC>
where
S: 'static + Unpin,
U: RealtimeUser + Unpin,
AC: CollabAccessControl + Unpin,
AC: RealtimeAccessControl + Unpin,
{
fn restarting(&mut self, _ctx: &mut Context<RealtimeServer<S, U, AC>>) {
warn!("restarting");
@ -518,7 +518,8 @@ impl CollabClientStream {
let can_sink = sink_filter(&cloned_object_id, &msg).await;
if can_sink {
// Send the message to websocket client actor
client_ws_sink.do_send(msg.into());
let rt_msg = msg.into();
client_ws_sink.do_send(rt_msg);
} else {
// when then client is not allowed to receive messages
tokio::time::sleep(Duration::from_secs(2)).await;
@ -540,10 +541,6 @@ impl CollabClientStream {
let _ = tx.send(Ok(msg)).await;
} else {
// when then client is not allowed to send messages
trace!(
"client:{} is not allowed to send messages",
msg.origin().client_user_id().unwrap_or(0)
);
tokio::time::sleep(Duration::from_secs(2)).await;
}
}


@ -0,0 +1,3 @@
-- Add migration script here
ALTER TABLE af_collab_member
ADD COLUMN created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT (NOW());


@ -2,7 +2,7 @@ use crate::api::util::{compress_type_from_header_value, device_id_from_headers};
use crate::api::ws::CollabServerImpl;
use crate::biz;
use crate::biz::workspace;
use crate::biz::workspace::access_control::WorkspaceAccessControl;
use crate::component::auth::jwt::UserUuid;
use crate::domain::compression::{decompress, CompressionType, X_COMPRESSION_TYPE};
use crate::state::AppState;
@ -32,6 +32,8 @@ use shared_entity::response::{AppResponse, JsonAppResponse};
use sqlx::types::uuid;
use tokio::time::{sleep, Instant};
use crate::biz::collab::access_control::CollabAccessControl;
use tokio_stream::StreamExt;
use tokio_tungstenite::tungstenite::Message;
use tracing::{event, instrument};
@ -41,15 +43,18 @@ use validator::Validate;
pub const WORKSPACE_ID_PATH: &str = "workspace_id";
pub const COLLAB_OBJECT_ID_PATH: &str = "object_id";
pub const WORKSPACE_PATTERN: &str = "/api/workspace";
pub const WORKSPACE_MEMBER_PATTERN: &str = "/api/workspace/{workspace_id}/member";
pub const COLLAB_PATTERN: &str = "/api/workspace/{workspace_id}/collab/{object_id}";
pub fn workspace_scope() -> Scope {
web::scope("/api/workspace")
// deprecated, use the api below instead
.service(web::resource("/list").route(web::get().to(list_workspace_handler)))
.service(web::resource("")
.route(web::get().to(list_workspace_handler))
.route(web::post().to(create_workpace_handler))
.route(web::post().to(create_workspace_handler))
.route(web::patch().to(patch_workspace_handler))
)
.service(web::resource("/{workspace_id}")
@ -60,7 +65,7 @@ pub fn workspace_scope() -> Scope {
.service(
web::resource("/{workspace_id}/member")
.route(web::get().to(get_workspace_members_handler))
.route(web::post().to(add_workspace_members_handler))
.route(web::post().to(create_workspace_members_handler))
.route(web::put().to(update_workspace_member_handler))
.route(web::delete().to(remove_workspace_member_handler)),
)
@ -121,7 +126,7 @@ pub fn collab_scope() -> Scope {
// Adds a workspace for user, if success, return the workspace id
#[instrument(skip_all, err)]
async fn create_workpace_handler(
async fn create_workspace_handler(
uuid: UserUuid,
state: Data<AppState>,
create_workspace_param: Json<CreateWorkspaceParam>,
@ -131,12 +136,11 @@ async fn create_workpace_handler(
.workspace_name
.unwrap_or_else(|| format!("workspace_{}", chrono::Utc::now().timestamp()));
let uid = state.users.get_user_uid(&uuid).await?;
let uid = state.user_cache.get_user_uid(&uuid).await?;
let new_workspace = workspace::ops::create_workspace_for_user(
&state.pg_pool,
&state.workspace_access_control,
&state.collab_access_control,
&state.collab_storage,
&state.collab_access_control_storage,
&uuid,
uid,
&workspace_name,
@ -201,7 +205,7 @@ async fn list_workspace_handler(
}
#[instrument(skip(payload, state), err)]
async fn add_workspace_members_handler(
async fn create_workspace_members_handler(
user_uuid: UserUuid,
workspace_id: web::Path<Uuid>,
payload: Json<CreateWorkspaceMembers>,
@ -242,7 +246,7 @@ async fn get_workspace_members_handler(
#[instrument(skip_all, err)]
async fn remove_workspace_member_handler(
user_uuid: UserUuid,
_user_uuid: UserUuid,
payload: Json<WorkspaceMembers>,
state: Data<AppState>,
workspace_id: web::Path<Uuid>,
@ -254,25 +258,13 @@ async fn remove_workspace_member_handler(
.map(|member| member.0)
.collect::<Vec<String>>();
workspace::ops::remove_workspace_members(
&user_uuid,
&state.pg_pool,
&workspace_id,
&member_emails,
&state.workspace_access_control,
)
.await?;
for email in member_emails {
if let Ok(uid) = select_uid_from_email(&state.pg_pool, &email)
.await
.map_err(AppResponseError::from)
{
state
.workspace_access_control
.remove_role(&uid, &workspace_id)
.await?;
}
}
Ok(AppResponse::Ok().into())
}
@ -295,16 +287,19 @@ async fn update_workspace_member_handler(
) -> Result<JsonAppResponse<()>> {
let workspace_id = workspace_id.into_inner();
let changeset = payload.into_inner();
workspace::ops::update_workspace_member(&state.pg_pool, &workspace_id, &changeset).await?;
if let Some(role) = changeset.role {
if changeset.role.is_some() {
let uid = select_uid_from_email(&state.pg_pool, &changeset.email)
.await
.map_err(AppResponseError::from)?;
state
.workspace_access_control
.insert_workspace_role(&uid, &workspace_id, role)
.await?;
workspace::ops::update_workspace_member(
&uid,
&state.pg_pool,
&workspace_id,
&changeset,
&state.workspace_access_control,
)
.await?;
}
Ok(AppResponse::Ok().into())
@ -317,7 +312,7 @@ async fn create_collab_handler(
state: Data<AppState>,
req: HttpRequest,
) -> Result<Json<AppResponse<()>>> {
let uid = state.users.get_user_uid(&user_uuid).await?;
let uid = state.user_cache.get_user_uid(&user_uuid).await?;
let params = match req.headers().get(X_COMPRESSION_TYPE) {
None => serde_json::from_slice::<CreateCollabParams>(&payload).map_err(|err| {
AppError::InvalidRequest(format!(
@ -340,11 +335,14 @@ async fn create_collab_handler(
params.validate().map_err(AppError::from)?;
let object_id = params.object_id.clone();
state.collab_storage.upsert_collab(&uid, params).await?;
state
.collab_access_control
.update_member(&uid, &object_id, AFAccessLevel::FullAccess)
.await;
.update_access_level_policy(&uid, &object_id, AFAccessLevel::FullAccess)
.await?;
state
.collab_access_control_storage
.insert_collab(&uid, params, false)
.await?;
Ok(Json(AppResponse::Ok()))
}
@ -357,7 +355,7 @@ async fn batch_create_collab_handler(
state: Data<AppState>,
req: HttpRequest,
) -> Result<Json<AppResponse<()>>> {
let uid = state.users.get_user_uid(&user_uuid).await?;
let uid = state.user_cache.get_user_uid(&user_uuid).await?;
let mut collab_params_list = vec![];
let workspace_id = workspace_id.into_inner().to_string();
let compress_type = compress_type_from_header_value(req.headers())?;
@ -426,14 +424,14 @@ async fn batch_create_collab_handler(
for params in collab_params_list {
let object_id = params.object_id.clone();
state
.collab_storage
.upsert_collab_with_transaction(&workspace_id, &uid, params, &mut transaction)
.collab_access_control_storage
.insert_or_update_collab(&workspace_id, &uid, params, &mut transaction)
.await?;
state
.collab_access_control
.update_member(&uid, &object_id, AFAccessLevel::FullAccess)
.await;
.update_access_level_policy(&uid, &object_id, AFAccessLevel::FullAccess)
.await?;
}
transaction
@ -451,7 +449,7 @@ async fn create_collab_list_handler(
state: Data<AppState>,
req: HttpRequest,
) -> Result<Json<AppResponse<()>>> {
let uid = state.users.get_user_uid(&user_uuid).await?;
let uid = state.user_cache.get_user_uid(&user_uuid).await?;
let params = match req.headers().get(X_COMPRESSION_TYPE) {
None => BatchCreateCollabParams::from_bytes(&payload).map_err(|err| {
AppError::InvalidRequest(format!(
@ -492,14 +490,14 @@ async fn create_collab_list_handler(
for params in params_list {
let object_id = params.object_id.clone();
state
.collab_storage
.upsert_collab_with_transaction(&workspace_id, &uid, params, &mut transaction)
.collab_access_control_storage
.insert_or_update_collab(&workspace_id, &uid, params, &mut transaction)
.await?;
state
.collab_access_control
.update_member(&uid, &object_id, AFAccessLevel::FullAccess)
.await;
.update_access_level_policy(&uid, &object_id, AFAccessLevel::FullAccess)
.await?;
}
transaction
@ -516,12 +514,12 @@ async fn get_collab_handler(
state: Data<AppState>,
) -> Result<Json<AppResponse<EncodedCollab>>> {
let uid = state
.users
.user_cache
.get_user_uid(&user_uuid)
.await
.map_err(AppResponseError::from)?;
let data = state
.collab_storage
.collab_access_control_storage
.get_collab_encoded(&uid, payload.into_inner(), false)
.await
.map_err(AppResponseError::from)?;
@ -537,7 +535,7 @@ async fn get_collab_snapshot_handler(
) -> Result<Json<AppResponse<SnapshotData>>> {
let (workspace_id, object_id) = path.into_inner();
let data = state
.collab_storage
.collab_access_control_storage
.get_collab_snapshot(&workspace_id.to_string(), &object_id, &payload.snapshot_id)
.await
.map_err(AppResponseError::from)?;
@ -555,12 +553,12 @@ async fn create_collab_snapshot_handler(
let (workspace_id, object_id) = path.into_inner();
let collab_type = payload.into_inner();
let uid = state
.users
.user_cache
.get_user_uid(&user_uuid)
.await
.map_err(AppResponseError::from)?;
let encoded_collab_v1 = state
.collab_storage
.collab_access_control_storage
.get_collab_encoded(
&uid,
QueryCollabParams::new(&object_id, collab_type, &workspace_id),
@ -571,7 +569,7 @@ async fn create_collab_snapshot_handler(
.unwrap();
let meta = state
.collab_storage
.collab_access_control_storage
.create_snapshot(InsertSnapshotParams {
object_id,
workspace_id,
@ -589,7 +587,7 @@ async fn get_all_collab_snapshot_list_handler(
) -> Result<Json<AppResponse<AFSnapshotMetas>>> {
let (_, object_id) = path.into_inner();
let data = state
.collab_storage
.collab_access_control_storage
.get_collab_snapshot_list(&object_id)
.await
.map_err(AppResponseError::from)?;
@ -603,13 +601,13 @@ async fn batch_get_collab_handler(
payload: Json<BatchQueryCollabParams>,
) -> Result<Json<AppResponse<BatchQueryCollabResult>>> {
let uid = state
.users
.user_cache
.get_user_uid(&user_uuid)
.await
.map_err(AppResponseError::from)?;
let result = BatchQueryCollabResult(
state
.collab_storage
.collab_access_control_storage
.batch_get_collab(&uid, payload.into_inner().0)
.await,
);
@ -623,12 +621,12 @@ async fn update_collab_handler(
state: Data<AppState>,
) -> Result<Json<AppResponse<()>>> {
let (params, workspace_id) = payload.into_inner().split();
let uid = state.users.get_user_uid(&user_uuid).await?;
let uid = state.user_cache.get_user_uid(&user_uuid).await?;
let create_params = CreateCollabParams::from((workspace_id.to_string(), params));
state
.collab_storage
.upsert_collab(&uid, create_params)
.collab_access_control_storage
.insert_collab(&uid, create_params, false)
.await?;
Ok(AppResponse::Ok().into())
}
@ -643,13 +641,13 @@ async fn delete_collab_handler(
payload.validate().map_err(AppError::from)?;
let uid = state
.users
.user_cache
.get_user_uid(&user_uuid)
.await
.map_err(AppResponseError::from)?;
state
.collab_storage
.collab_access_control_storage
.delete_collab(&uid, &payload.object_id)
.await
.map_err(AppResponseError::from)?;
@ -663,11 +661,8 @@ async fn add_collab_member_handler(
state: Data<AppState>,
) -> Result<Json<AppResponse<()>>> {
let payload = payload.into_inner();
biz::collab::ops::create_collab_member(&state.pg_pool, &payload).await?;
state
.collab_access_control
.update_member(&payload.uid, &payload.object_id, payload.access_level)
.await;
biz::collab::ops::create_collab_member(&state.pg_pool, &payload, &state.collab_access_control)
.await?;
Ok(Json(AppResponse::Ok()))
}
@ -678,13 +673,13 @@ async fn update_collab_member_handler(
state: Data<AppState>,
) -> Result<Json<AppResponse<()>>> {
let payload = payload.into_inner();
biz::collab::ops::upsert_collab_member(&state.pg_pool, &user_uuid, &payload).await?;
state
.collab_access_control
.update_member(&payload.uid, &payload.object_id, payload.access_level)
.await;
biz::collab::ops::upsert_collab_member(
&state.pg_pool,
&user_uuid,
&payload,
&state.collab_access_control,
)
.await?;
Ok(Json(AppResponse::Ok()))
}
#[instrument(level = "debug", skip(state, payload), err)]
@ -703,11 +698,8 @@ async fn remove_collab_member_handler(
state: Data<AppState>,
) -> Result<Json<AppResponse<()>>> {
let payload = payload.into_inner();
biz::collab::ops::delete_collab_member(&state.pg_pool, &payload).await?;
state
.collab_access_control
.remove_member(&payload.uid, &payload.object_id)
.await;
biz::collab::ops::delete_collab_member(&state.pg_pool, &payload, &state.collab_access_control)
.await?;
Ok(Json(AppResponse::Ok()))
}
@ -733,7 +725,7 @@ async fn post_realtime_message_stream_handler(
// TODO(nathan): after upgrade the client application, then the device_id should not be empty
let device_id = device_id_from_headers(req.headers()).unwrap_or_else(|_| "".to_string());
let uid = state
.users
.user_cache
.get_user_uid(&user_uuid)
.await
.map_err(AppResponseError::from)?;


@ -8,11 +8,11 @@ use std::sync::Arc;
use realtime::client::RealtimeClient;
use realtime::collaborate::RealtimeServer;
use crate::biz::collab::storage::CollabStorageImpl;
use crate::biz::collab::storage::CollabAccessControlStorage;
use crate::biz::user::RealtimeUserImpl;
use crate::component::auth::jwt::{authorization_from_token, UserUuid};
use crate::biz::casbin::CollabAccessControlImpl;
use crate::biz::casbin::RealtimeCollabAccessControlImpl;
use shared_entity::response::AppResponseError;
use std::time::Duration;
use tracing::{info, instrument};
@ -22,8 +22,13 @@ pub fn ws_scope() -> Scope {
}
const MAX_FRAME_SIZE: usize = 65_536; // 64 KiB
pub type CollabServerImpl =
Addr<RealtimeServer<CollabStorageImpl, Arc<RealtimeUserImpl>, CollabAccessControlImpl>>;
pub type CollabServerImpl = Addr<
RealtimeServer<
CollabAccessControlStorage,
Arc<RealtimeUserImpl>,
RealtimeCollabAccessControlImpl,
>,
>;
#[instrument(skip_all, err)]
#[get("/{token}/{device_id}")]
@ -37,7 +42,7 @@ pub async fn establish_ws_connection(
let (token, device_id) = path.into_inner();
let auth = authorization_from_token(token.as_str(), &state)?;
let user_uuid = UserUuid::from_auth(auth)?;
let result = state.users.get_user_uid(&user_uuid).await;
let result = state.user_cache.get_user_uid(&user_uuid).await;
match result {
Ok(uid) => {


@ -6,14 +6,20 @@ use crate::api::workspace::{collab_scope, workspace_scope};
use crate::api::ws::ws_scope;
use crate::biz::casbin::access_control::AccessControl;
use crate::biz::casbin::enforcer_cache::AFEnforcerCacheImpl;
use crate::biz::collab::access_control::CollabHttpAccessControl;
use crate::biz::collab::storage::init_collab_storage;
use crate::biz::casbin::RealtimeCollabAccessControlImpl;
use crate::biz::collab::access_control::{
CollabMiddlewareAccessControl, CollabStorageAccessControlImpl,
};
use crate::biz::collab::cache::CollabCache;
use crate::biz::collab::storage::CollabStorageImpl;
use crate::biz::pg_listener::PgListeners;
use crate::biz::snapshot::SnapshotControl;
use crate::biz::user::RealtimeUserImpl;
use crate::biz::workspace::access_control::WorkspaceHttpAccessControl;
use crate::biz::workspace::access_control::WorkspaceMiddlewareAccessControl;
use crate::component::auth::HEADER_TOKEN;
use crate::config::config::{Config, DatabaseSetting, GoTrueSetting, S3Setting};
use crate::middleware::access_control_mw::WorkspaceAccessControl;
use crate::middleware::access_control_mw::MiddlewareAccessControlTransform;
use crate::middleware::metrics_mw::MetricsMiddleware;
use crate::middleware::request_id::RequestIdMiddleware;
use crate::self_signed::create_self_signed_certificate;
@ -87,21 +93,21 @@ pub async fn run(
.map(|(_, server_key)| Key::from(server_key.expose_secret().as_bytes()))
.unwrap_or_else(Key::generate);
let storage = state.collab_storage.clone();
let access_control = WorkspaceAccessControl::new()
.with_acs(WorkspaceHttpAccessControl {
pg_pool: state.pg_pool.clone(),
access_control: state.workspace_access_control.clone().into(),
})
.with_acs(CollabHttpAccessControl(
let storage = state.collab_access_control_storage.clone();
let access_control = MiddlewareAccessControlTransform::new()
.with_acs(WorkspaceMiddlewareAccessControl::new(
state.pg_pool.clone(),
state.workspace_access_control.clone().into(),
))
.with_acs(CollabMiddlewareAccessControl::new(
state.collab_access_control.clone().into(),
state.collab_cache.clone(),
));
// Initialize metrics that which are registered in the registry.
let realtime_server = RealtimeServer::<_, Arc<RealtimeUserImpl>, _>::new(
storage.clone(),
state.collab_access_control.clone(),
RealtimeCollabAccessControlImpl::new(state.access_control.clone()),
state.metrics.realtime_metrics.clone(),
rt_cmd_recv,
)
@ -119,8 +125,8 @@ pub async fn run(
.build(),
)
// .wrap(DecryptPayloadMiddleware)
.wrap(RequestIdMiddleware)
.wrap(access_control.clone())
.wrap(RequestIdMiddleware)
.app_data(web::JsonConfig::default().limit(5 * 1024 * 1024))
.service(user_scope())
.service(workspace_scope())
@ -189,7 +195,7 @@ pub async fn init_state(config: &Config, rt_cmd_tx: RTCommandSender) -> Result<A
let workspace_member_listener = pg_listeners.subscribe_workspace_member_change();
info!("Setting up access controls...");
let enforce_cache = Arc::new(AFEnforcerCacheImpl::new(redis_client.clone()));
let enforce_cache = AFEnforcerCacheImpl::new(redis_client.clone());
let access_control = AccessControl::new(
pg_pool.clone(),
collab_member_listener,
@ -199,31 +205,39 @@ pub async fn init_state(config: &Config, rt_cmd_tx: RTCommandSender) -> Result<A
)
.await?;
let user_cache = UserCache::new(pg_pool.clone()).await;
let collab_access_control = access_control.new_collab_access_control();
let workspace_access_control = access_control.new_workspace_access_control();
let collab_cache = CollabCache::new(redis_client.clone(), pg_pool.clone());
let collab_storage = Arc::new(
init_collab_storage(
pg_pool.clone(),
redis_client.clone(),
collab_access_control.clone(),
workspace_access_control.clone(),
metrics.collab_metrics.clone(),
rt_cmd_tx,
)
.await,
);
let users = UserCache::new(pg_pool.clone()).await;
let collab_storage_access_control = CollabStorageAccessControlImpl {
collab_access_control: collab_access_control.clone().into(),
workspace_access_control: workspace_access_control.clone().into(),
cache: collab_cache.clone(),
};
let snapshot_control = SnapshotControl::new(
redis_client.clone(),
pg_pool.clone(),
metrics.collab_metrics.clone(),
)
.await;
let collab_storage = Arc::new(CollabStorageImpl::new(
collab_cache.clone(),
collab_storage_access_control,
snapshot_control,
rt_cmd_tx,
));
info!("Application state initialized");
Ok(AppState {
pg_pool,
config: Arc::new(config.clone()),
users: Arc::new(users),
user_cache,
id_gen: Arc::new(RwLock::new(Snowflake::new(1))),
gotrue_client,
redis_client,
collab_storage,
collab_cache,
collab_access_control_storage: collab_storage,
collab_access_control,
workspace_access_control,
bucket_storage,


@ -14,6 +14,7 @@ use anyhow::anyhow;
use sqlx::PgPool;
use redis::{ErrorKind, FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs, Value};
use std::sync::Arc;
use tokio::sync::broadcast;
@ -43,16 +44,12 @@ impl AccessControl {
collab_listener: broadcast::Receiver<CollabMemberNotification>,
workspace_listener: broadcast::Receiver<WorkspaceMemberNotification>,
access_control_metrics: Arc<AccessControlMetrics>,
enforcer_cache: Arc<dyn AFEnforcerCache>,
enforcer_cache: impl AFEnforcerCache + 'static,
) -> Result<Self, AppError> {
let access_control_model = casbin::DefaultModel::from_str(MODEL_CONF)
.await
.map_err(|e| AppError::Internal(anyhow!("Failed to create access control model: {}", e)))?;
let access_control_adapter = PgAdapter::new(
pg_pool.clone(),
enforcer_cache.clone(),
access_control_metrics.clone(),
);
let access_control_adapter = PgAdapter::new(pg_pool.clone(), access_control_metrics.clone());
let enforcer = casbin::Enforcer::new(access_control_model, access_control_adapter)
.await
.map_err(|e| {
@ -79,7 +76,7 @@ impl AccessControl {
WorkspaceAccessControlImpl::new(self.clone())
}
pub async fn update(
pub async fn update_policy(
&self,
uid: &i64,
obj: &ObjectType<'_>,
@ -88,53 +85,27 @@ impl AccessControl {
if cfg!(feature = "disable_access_control") {
Ok(true)
} else {
self.enforcer.update(uid, obj, act).await
self.enforcer.update_policy(uid, obj, act).await
}
}
pub async fn remove(&self, uid: &i64, obj: &ObjectType<'_>) -> Result<(), AppError> {
pub async fn remove_policy(&self, uid: &i64, obj: &ObjectType<'_>) -> Result<(), AppError> {
if cfg!(feature = "disable_access_control") {
Ok(())
} else {
self.enforcer.remove(uid, obj).await?;
self.enforcer.remove_policy(uid, obj).await?;
Ok(())
}
}
pub async fn enforce<A>(&self, uid: &i64, obj: &ObjectType<'_>, act: A) -> Result<bool, AppError>
where
A: ToCasbinAction,
A: ToACAction,
{
if cfg!(feature = "disable_access_control") {
Ok(true)
} else {
self.enforcer.enforce(uid, obj, act).await
}
}
pub async fn get_access_level(&self, uid: &i64, oid: &str) -> Option<AFAccessLevel> {
if cfg!(feature = "disable_access_control") {
Some(AFAccessLevel::FullAccess)
} else {
let collab_id = ObjectType::Collab(oid);
self
.enforcer
.get_action(uid, &collab_id)
.await
.map(|value| AFAccessLevel::from_action(&value))
}
}
pub async fn get_role(&self, uid: &i64, workspace_id: &str) -> Option<AFRole> {
if cfg!(feature = "disable_access_control") {
Some(AFRole::Owner)
} else {
let workspace_id = ObjectType::Workspace(workspace_id);
self
.enforcer
.get_action(uid, &workspace_id)
.await
.map(|value| AFRole::from_action(&value))
self.enforcer.enforce_policy(uid, obj, act).await
}
}
}
@ -231,8 +202,8 @@ pub enum ActionType {
Level(AFAccessLevel),
}
impl ToCasbinAction for ActionType {
fn to_action(&self) -> String {
impl ToACAction for ActionType {
fn to_action(&self) -> &str {
match self {
ActionType::Role(role) => role.to_action(),
ActionType::Level(level) => level.to_action(),
@ -241,26 +212,41 @@ impl ToCasbinAction for ActionType {
}
/// Represents the actions that can be performed on objects.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum Action {
Read,
Write,
Delete,
}
impl ToCasbinAction for Action {
fn to_action(&self) -> String {
match self {
Action::Read => "read".to_owned(),
Action::Write => "write".to_owned(),
Action::Delete => "delete".to_owned(),
impl ToRedisArgs for Action {
fn write_redis_args<W>(&self, out: &mut W)
where
W: ?Sized + RedisWrite,
{
self.to_action().write_redis_args(out)
}
}
impl FromRedisValue for Action {
fn from_redis_value(v: &Value) -> RedisResult<Self> {
let s: String = FromRedisValue::from_redis_value(v)?;
match s.as_str() {
"read" => Ok(Action::Read),
"write" => Ok(Action::Write),
"delete" => Ok(Action::Delete),
_ => Err(RedisError::from((ErrorKind::TypeError, "invalid action"))),
}
}
}
impl From<Method> for Action {
fn from(method: Method) -> Self {
Self::from(&method)
impl ToACAction for Action {
fn to_action(&self) -> &str {
match self {
Action::Read => "read",
Action::Write => "write",
Action::Delete => "delete",
}
}
}
@ -275,31 +261,40 @@ impl From<&Method> for Action {
}
}
pub trait ToCasbinAction {
fn to_action(&self) -> String;
pub trait ToACAction {
fn to_action(&self) -> &str;
}
pub trait FromCasbinAction {
pub trait FromACAction {
fn from_action(action: &str) -> Self;
}
impl ToCasbinAction for AFAccessLevel {
fn to_action(&self) -> String {
i32::from(self).to_string()
impl ToACAction for AFAccessLevel {
fn to_action(&self) -> &str {
match self {
AFAccessLevel::ReadOnly => "10",
AFAccessLevel::ReadAndComment => "20",
AFAccessLevel::ReadAndWrite => "30",
AFAccessLevel::FullAccess => "50",
}
}
}
impl FromCasbinAction for AFAccessLevel {
impl FromACAction for AFAccessLevel {
fn from_action(action: &str) -> Self {
Self::from(action)
}
}
impl ToCasbinAction for AFRole {
fn to_action(&self) -> String {
i32::from(self).to_string()
impl ToACAction for AFRole {
fn to_action(&self) -> &str {
match self {
AFRole::Owner => "1",
AFRole::Member => "2",
AFRole::Guest => "3",
}
}
}
impl FromCasbinAction for AFRole {
impl FromACAction for AFRole {
fn from_action(action: &str) -> Self {
Self::from(action)
}
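For orientation, the numeric strings above are exactly what gets stored as the action value of a casbin policy. A minimal standalone sketch of the same round-trip; the enum and traits below are local stand-ins for the `database_entity::dto` types and the `ToACAction`/`FromACAction` traits, not the real definitions:

```rust
// Local stand-ins for the dto types; only the string mapping is mirrored here.
trait ToACAction {
    fn to_action(&self) -> &str;
}

trait FromACAction {
    fn from_action(action: &str) -> Self;
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum AccessLevel {
    ReadOnly,
    ReadAndComment,
    ReadAndWrite,
    FullAccess,
}

impl ToACAction for AccessLevel {
    fn to_action(&self) -> &str {
        match self {
            AccessLevel::ReadOnly => "10",
            AccessLevel::ReadAndComment => "20",
            AccessLevel::ReadAndWrite => "30",
            AccessLevel::FullAccess => "50",
        }
    }
}

impl FromACAction for AccessLevel {
    fn from_action(action: &str) -> Self {
        // Unknown strings fall back to the most restrictive level in this sketch.
        match action {
            "20" => AccessLevel::ReadAndComment,
            "30" => AccessLevel::ReadAndWrite,
            "50" => AccessLevel::FullAccess,
            _ => AccessLevel::ReadOnly,
        }
    }
}

fn main() {
    let level = AccessLevel::ReadAndWrite;
    let action = level.to_action();
    assert_eq!(AccessLevel::from_action(action), level);
    println!("{level:?} <-> {action}");
}
```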

View file

@ -1,5 +1,5 @@
use crate::biz::casbin::access_control::{Action, ObjectType, ToCasbinAction};
use crate::biz::casbin::enforcer::{AFEnforcerCache, ActionCacheKey};
use crate::biz::casbin::access_control::{Action, ObjectType, ToACAction};
use async_trait::async_trait;
use crate::biz::casbin::metrics::AccessControlMetrics;
@ -18,32 +18,24 @@ use sqlx::PgPool;
use std::sync::Arc;
use std::time::Instant;
use tokio_stream::StreamExt;
use tracing::error;
/// Implementation of [`casbin::Adapter`] for access control authorisation.
/// Access control policies that are managed by workspace and collab CRUD.
pub struct PgAdapter {
pg_pool: PgPool,
access_control_metrics: Arc<AccessControlMetrics>,
enforce_cache: Arc<dyn AFEnforcerCache>,
}
impl PgAdapter {
pub fn new(
pg_pool: PgPool,
enforce_cache: Arc<dyn AFEnforcerCache>,
access_control_metrics: Arc<AccessControlMetrics>,
) -> Self {
pub fn new(pg_pool: PgPool, access_control_metrics: Arc<AccessControlMetrics>) -> Self {
Self {
pg_pool,
enforce_cache,
access_control_metrics,
}
}
}
async fn load_collab_policies(
enforce_cache: &Arc<dyn AFEnforcerCache>,
mut stream: BoxStream<'_, sqlx::Result<AFCollabMemerAccessLevelRow>>,
) -> Result<Vec<Vec<String>>> {
let mut policies: Vec<Vec<String>> = Vec::new();
@ -52,14 +44,12 @@ async fn load_collab_policies(
let uid = member_access_lv.uid;
let object_type = ObjectType::Collab(&member_access_lv.oid);
let action = member_access_lv.access_level.to_action();
if let Err(err) = enforce_cache
.set_action(&ActionCacheKey::new(&uid, &object_type), action.clone())
.await
{
error!("{}", err)
}
let policy = [uid.to_string(), object_type.to_object_id(), action].to_vec();
let policy = [
uid.to_string(),
object_type.to_object_id(),
action.to_string(),
]
.to_vec();
policies.push(policy);
}
@ -67,7 +57,6 @@ async fn load_collab_policies(
}
async fn load_workspace_policies(
enforce_cache: &Arc<dyn AFEnforcerCache>,
mut stream: BoxStream<'_, sqlx::Result<AFWorkspaceMemberPermRow>>,
) -> Result<Vec<Vec<String>>> {
let mut policies: Vec<Vec<String>> = Vec::new();
@ -77,14 +66,12 @@ async fn load_workspace_policies(
let workspace_id = member_permission.workspace_id.to_string();
let object_type = ObjectType::Workspace(&workspace_id);
let action = member_permission.role.to_action();
if let Err(err) = enforce_cache
.set_action(&ActionCacheKey::new(&uid, &object_type), action.clone())
.await
{
error!("{}", err);
}
let policy = [uid.to_string(), object_type.to_object_id(), action].to_vec();
let policy = [
uid.to_string(),
object_type.to_object_id(),
action.to_string(),
]
.to_vec();
policies.push(policy);
}
@ -96,15 +83,13 @@ impl Adapter for PgAdapter {
async fn load_policy(&mut self, model: &mut dyn Model) -> Result<()> {
let start = Instant::now();
let workspace_member_perm_stream = select_workspace_member_perm_stream(&self.pg_pool);
let workspace_policies =
load_workspace_policies(&self.enforce_cache, workspace_member_perm_stream).await?;
let workspace_policies = load_workspace_policies(workspace_member_perm_stream).await?;
// Policy definition `p` of type `p`. See `model.conf`
model.add_policies("p", "p", workspace_policies);
let collab_member_access_lv_stream = select_collab_member_access_level(&self.pg_pool);
let collab_policies =
load_collab_policies(&self.enforce_cache, collab_member_access_lv_stream).await?;
let collab_policies = load_collab_policies(collab_member_access_lv_stream).await?;
// Policy definition `p` of type `p`. See `model.conf`
model.add_policies("p", "p", collab_policies);
@ -117,7 +102,7 @@ impl Adapter for PgAdapter {
AFAccessLevel::FullAccess,
];
let mut grouping_policies = Vec::new();
for level in af_access_levels {
for level in &af_access_levels {
// All levels can read
grouping_policies.push([level.to_action(), Action::Read.to_action()].to_vec());
if level.can_write() {
@ -129,7 +114,7 @@ impl Adapter for PgAdapter {
}
let af_roles = [AFRole::Owner, AFRole::Member, AFRole::Guest];
for role in af_roles {
for role in &af_roles {
match role {
AFRole::Owner => {
grouping_policies.push([role.to_action(), Action::Delete.to_action()].to_vec());
@ -145,6 +130,11 @@ impl Adapter for PgAdapter {
},
}
}
let grouping_policies = grouping_policies
.into_iter()
.map(|actions| actions.into_iter().map(|a| a.to_string()).collect())
.collect();
// Grouping definition `g` of type `g`. See `model.conf`
model.add_policies("g", "g", grouping_policies);
self
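To make the loaded rules concrete, the `p` and `g` rows this adapter hands to the model look roughly like the sketch below. The ids are placeholders, the object-id format is whatever `ObjectType::to_object_id` produces, and the delete grouping for `FullAccess` is an assumption based on `enforce_delete` elsewhere in this change:

```rust
// Illustrative policy ("p") and grouping ("g") rows; ids are placeholders.
fn main() {
    // p: [subject uid, object id, action] — a workspace role and a collab access level.
    let workspace_policy = vec!["254".to_string(), "workspace-id".to_string(), "1".to_string()]; // AFRole::Owner
    let collab_policy = vec!["254".to_string(), "object-id".to_string(), "50".to_string()]; // AFAccessLevel::FullAccess

    // g: groups an access-level action with the base actions it implies
    // (every level maps to read; writable levels also map to write;
    // FullAccess is assumed to map to delete as well).
    let grouping: Vec<Vec<String>> = vec![
        vec!["50".into(), "read".into()],
        vec!["50".into(), "write".into()],
        vec!["50".into(), "delete".into()],
    ];

    println!("{workspace_policy:?}\n{collab_policy:?}\n{grouping:?}");
}
```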

View file

@ -1,12 +1,10 @@
use crate::biz::casbin::access_control::{AccessControl, Action};
use crate::biz::casbin::access_control::{ActionType, ObjectType};
use actix_http::Method;
use crate::biz::collab::access_control::CollabAccessControl;
use app_error::AppError;
use async_trait::async_trait;
use database_entity::dto::AFAccessLevel;
use realtime::collaborate::CollabAccessControl;
use realtime::collaborate::RealtimeAccessControl;
use tracing::instrument;
#[derive(Clone)]
@ -18,42 +16,31 @@ impl CollabAccessControlImpl {
pub fn new(access_control: AccessControl) -> Self {
Self { access_control }
}
#[instrument(level = "info", skip_all)]
pub async fn update_member(&self, uid: &i64, oid: &str, access_level: AFAccessLevel) {
let _ = self
.access_control
.update(
uid,
&ObjectType::Collab(oid),
&ActionType::Level(access_level),
)
.await;
}
pub async fn remove_member(&self, uid: &i64, oid: &str) {
let _ = self
.access_control
.remove(uid, &ObjectType::Collab(oid))
.await;
}
}
#[async_trait]
impl CollabAccessControl for CollabAccessControlImpl {
async fn get_collab_access_level(&self, uid: &i64, oid: &str) -> Result<AFAccessLevel, AppError> {
async fn enforce_action(&self, uid: &i64, oid: &str, action: Action) -> Result<bool, AppError> {
self
.access_control
.get_access_level(uid, oid)
.enforce(uid, &ObjectType::Collab(oid), action)
.await
.ok_or_else(|| {
AppError::RecordNotFound(format!(
"can't find the access level for user:{} of {} in cache",
uid, oid
))
})
}
#[instrument(level = "trace", skip_all)]
async fn insert_collab_access_level(
async fn enforce_access_level(
&self,
uid: &i64,
oid: &str,
access_level: AFAccessLevel,
) -> Result<bool, AppError> {
self
.access_control
.enforce(uid, &ObjectType::Collab(oid), access_level)
.await
}
#[instrument(level = "info", skip_all)]
async fn update_access_level_policy(
&self,
uid: &i64,
oid: &str,
@ -61,25 +48,35 @@ impl CollabAccessControl for CollabAccessControlImpl {
) -> Result<(), AppError> {
self
.access_control
.update(uid, &ObjectType::Collab(oid), &ActionType::Level(level))
.update_policy(uid, &ObjectType::Collab(oid), &ActionType::Level(level))
.await?;
Ok(())
}
async fn can_access_http_method(
&self,
uid: &i64,
oid: &str,
method: &Method,
) -> Result<bool, AppError> {
let action = Action::from(method);
#[instrument(level = "info", skip_all)]
async fn remove_access_level(&self, uid: &i64, oid: &str) -> Result<(), AppError> {
self
.access_control
.enforce(uid, &ObjectType::Collab(oid), action)
.await
.remove_policy(uid, &ObjectType::Collab(oid))
.await?;
Ok(())
}
}
#[derive(Clone)]
pub struct RealtimeCollabAccessControlImpl {
access_control: AccessControl,
}
impl RealtimeCollabAccessControlImpl {
pub fn new(access_control: AccessControl) -> Self {
Self { access_control }
}
}
#[async_trait]
impl RealtimeAccessControl for RealtimeCollabAccessControlImpl {
async fn can_send_collab_update(&self, uid: &i64, oid: &str) -> Result<bool, AppError> {
if cfg!(feature = "disable_access_control") {
Ok(true)

View file

@ -1,6 +1,5 @@
use crate::biz::casbin::access_control::{
ActionType, ObjectType, ToCasbinAction, POLICY_FIELD_INDEX_ACTION, POLICY_FIELD_INDEX_OBJECT,
POLICY_FIELD_INDEX_USER,
ActionType, ObjectType, ToACAction, POLICY_FIELD_INDEX_OBJECT, POLICY_FIELD_INDEX_USER,
};
use anyhow::anyhow;
use app_error::AppError;
@ -14,32 +13,33 @@ use std::time::Duration;
use crate::biz::casbin::metrics::AccessControlMetrics;
use tokio::sync::RwLock;
use tokio::sync::{Mutex, RwLock};
use tokio::time::interval;
use tracing::{error, event, trace};
use tracing::{error, event, instrument, trace};
#[async_trait]
pub trait AFEnforcerCache: Send + Sync {
async fn set_enforcer_result(&self, key: &PolicyCacheKey, value: bool) -> Result<(), AppError>;
async fn get_enforcer_result(&self, key: &PolicyCacheKey) -> Option<bool>;
async fn remove_enforcer_result(&self, key: &PolicyCacheKey);
async fn set_action(&self, key: &ActionCacheKey, value: String) -> Result<(), AppError>;
async fn get_action(&self, key: &ActionCacheKey) -> Option<String>;
async fn remove_action(&self, key: &ActionCacheKey);
async fn set_enforcer_result(
&mut self,
key: &PolicyCacheKey,
value: bool,
) -> Result<(), AppError>;
async fn get_enforcer_result(&mut self, key: &PolicyCacheKey) -> Option<bool>;
async fn remove_enforcer_result(&mut self, key: &PolicyCacheKey);
}
pub const ENFORCER_METRICS_TICK_INTERVAL: Duration = Duration::from_secs(30);
pub struct AFEnforcer {
enforcer: RwLock<Enforcer>,
cache: Arc<dyn AFEnforcerCache>,
cache: Arc<Mutex<dyn AFEnforcerCache>>,
metrics_cal: MetricsCal,
}
impl AFEnforcer {
pub fn new(
enforcer: Enforcer,
cache: Arc<dyn AFEnforcerCache>,
cache: impl AFEnforcerCache + 'static,
metrics: Arc<AccessControlMetrics>,
) -> Self {
let metrics_cal = MetricsCal::new();
@ -63,85 +63,150 @@ impl AFEnforcer {
Self {
enforcer: RwLock::new(enforcer),
cache,
cache: Arc::new(Mutex::new(cache)),
metrics_cal,
}
}
/// Update permission for a user.
/// Update policy for a user.
/// If the policy already exists, it returns Ok(false).
///
/// [`ObjectType::Workspace`] has to be paired with [`ActionType::Role`],
/// [`ObjectType::Collab`] has to be paired with [`ActionType::Level`],
pub async fn update(
#[instrument(level = "debug", skip_all, err)]
pub async fn update_policy(
&self,
uid: &i64,
obj: &ObjectType<'_>,
act: &ActionType,
) -> Result<bool, AppError> {
validate_obj_action(obj, act)?;
let policy = vec![uid.to_string(), obj.to_object_id(), act.to_action()];
let policy = vec![
uid.to_string(),
obj.to_object_id(),
act.to_action().to_string(),
];
let policy_key = PolicyCacheKey::new(&policy);
// If the policy is already in the cache, return the cached result; only update it when it is absent.
if let Some(value) = self.cache.get_enforcer_result(&policy_key).await {
if let Some(value) = self
.cache
.lock()
.await
.get_enforcer_result(&policy_key)
.await
{
return Ok(value);
}
// Only one policy is kept per user per object, so remove the old policy before adding the new one.
let mut write_guard = self.enforcer.write().await;
let _remove_policies = self
.remove_with_enforcer(uid, obj, &mut write_guard)
.await?;
let result = write_guard
.add_policy(policy)
.await
.map_err(|e| AppError::Internal(anyhow!("fail to add policy: {e:?}")));
trace!(
"[access control]: add policy:{} => {:?}",
policy_key.0,
result
);
drop(write_guard);
let object_key = ActionCacheKey::new(uid, obj);
match &result {
Ok(value) => {
trace!("[access control]: add policy:{} => {}", policy_key.0, value);
if let Err(err) = self.cache.set_action(&object_key, act.to_action()).await {
error!("{}", err);
}
},
Err(err) => {
trace!(
"[access control]: fail to add policy:{} => {:?}",
policy_key.0,
err
);
},
}
result
}
/// Returns policies that match the filter.
pub async fn remove(
pub async fn remove_policy(
&self,
uid: &i64,
object_type: &ObjectType<'_>,
) -> Result<Vec<Vec<String>>, AppError> {
) -> Result<(), AppError> {
let mut enforcer = self.enforcer.write().await;
self
.remove_with_enforcer(uid, object_type, &mut enforcer)
.await
}
pub async fn remove_with_enforcer(
#[instrument(level = "debug", skip_all)]
pub async fn enforce_policy<A>(
&self,
uid: &i64,
obj: &ObjectType<'_>,
act: A,
) -> Result<bool, AppError>
where
A: ToACAction,
{
self
.metrics_cal
.total_read_enforce_result
.fetch_add(1, Ordering::Relaxed);
// create policy request
let policy_request = vec![
uid.to_string(),
obj.to_object_id(),
act.to_action().to_string(),
];
let policy_key = PolicyCacheKey::new(&policy_request);
// If the enforcement result is already cached, return it; only evaluate the policy when it is absent.
if let Some(value) = self
.cache
.lock()
.await
.get_enforcer_result(&policy_key)
.await
{
self
.metrics_cal
.read_enforce_result_from_cache
.fetch_add(1, Ordering::Relaxed);
return Ok(value);
}
// Perform the action and capture the result or error
let action_result = self.enforcer.read().await.enforce(policy_request);
match &action_result {
Ok(result) => trace!(
"[access control]: enforce policy:{} with result:{}",
policy_key.0,
result
),
Err(e) => trace!(
"[access control]: enforce policy:{} with error: {:?}",
policy_key.0,
e
),
}
// Map the enforcement result into this method's result type.
let result = action_result.map_err(|e| AppError::Internal(anyhow!("enforce: {e:?}")))?;
if let Err(err) = self
.cache
.lock()
.await
.set_enforcer_result(&policy_key, result)
.await
{
error!("{}", err)
}
Ok(result)
}
#[inline]
async fn remove_with_enforcer(
&self,
uid: &i64,
object_type: &ObjectType<'_>,
enforcer: &mut Enforcer,
) -> Result<Vec<Vec<String>>, AppError> {
) -> Result<(), AppError> {
let policies_for_user_on_object =
policies_for_user_with_given_object(uid, object_type, enforcer).await;
// if there are no policies for the user on the object, return early.
if policies_for_user_on_object.is_empty() {
return Ok(vec![]);
return Ok(());
}
event!(
@ -151,78 +216,20 @@ impl AFEnforcer {
uid,
policies_for_user_on_object
);
debug_assert!(
policies_for_user_on_object.len() == 1,
"only one policy per user per object"
);
enforcer
.remove_policies(policies_for_user_on_object.clone())
.await
.map_err(|e| AppError::Internal(anyhow!("error enforce: {e:?}")))?;
let object_key = ActionCacheKey::new(uid, object_type);
self.cache.remove_action(&object_key).await;
let mut cache_lock_guard = self.cache.lock().await;
for policy in &policies_for_user_on_object {
self
.cache
cache_lock_guard
.remove_enforcer_result(&PolicyCacheKey::new(policy))
.await;
}
drop(cache_lock_guard);
Ok(policies_for_user_on_object)
}
pub async fn enforce<A>(&self, uid: &i64, obj: &ObjectType<'_>, act: A) -> Result<bool, AppError>
where
A: ToCasbinAction,
{
self
.metrics_cal
.total_read_enforce_result
.fetch_add(1, Ordering::Relaxed);
let policy = vec![uid.to_string(), obj.to_object_id(), act.to_action()];
let policy_key = PolicyCacheKey::new(&policy);
if let Some(value) = self.cache.get_enforcer_result(&policy_key).await {
self
.metrics_cal
.read_enforce_result_from_cache
.fetch_add(1, Ordering::Relaxed);
return Ok(value);
}
let read_guard = self.enforcer.read().await;
let policies_for_object =
read_guard.get_filtered_policy(POLICY_FIELD_INDEX_OBJECT, vec![obj.to_object_id()]);
if policies_for_object.is_empty() {
return Ok(true);
}
let result = read_guard
.enforce(policy)
enforcer
.remove_policies(policies_for_user_on_object)
.await
.map_err(|e| AppError::Internal(anyhow!("error enforce: {e:?}")))?;
drop(read_guard);
trace!("[access control]: policy:{} => {}", policy_key.0, result);
if let Err(err) = self.cache.set_enforcer_result(&policy_key, result).await {
error!("{}", err)
}
Ok(result)
}
pub async fn get_action(&self, uid: &i64, object_type: &ObjectType<'_>) -> Option<String> {
let object_key = ActionCacheKey::new(uid, object_type);
if let Some(value) = self.cache.get_action(&object_key).await {
return Some(value.clone());
}
// There should only be one entry per user per object, which is enforced in [AccessControl], so just take one using next.
let policies =
policies_for_user_with_given_object(uid, object_type, &*self.enforcer.read().await).await;
let action = policies.first()?[POLICY_FIELD_INDEX_ACTION].clone();
trace!("cache action: {}:{}", object_key.0, action.clone());
let _ = self.cache.set_action(&object_key, action.clone()).await;
Some(action)
Ok(())
}
}
@ -252,28 +259,6 @@ impl AsRef<str> for PolicyCacheKey {
}
}
#[derive(Debug, Hash, Eq, PartialEq)]
pub struct ActionCacheKey(String);
impl AsRef<str> for ActionCacheKey {
fn as_ref(&self) -> &str {
&self.0
}
}
impl Deref for ActionCacheKey {
type Target = str;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl ActionCacheKey {
pub(crate) fn new(uid: &i64, object_type: &ObjectType<'_>) -> Self {
Self(format!("{}:{}", uid, object_type.to_object_id()))
}
}
fn validate_obj_action(obj: &ObjectType<'_>, act: &ActionType) -> Result<(), AppError> {
match (obj, act) {
(ObjectType::Workspace(_), ActionType::Role(_))
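A hedged usage sketch of the pairing rule that `validate_obj_action` guards: workspace objects take `ActionType::Role`, collab objects take `ActionType::Level`, and enforcement accepts any `ToACAction` such as a plain `Action`. Module paths are assumptions read off this diff, so treat it as an outline rather than a verified snippet:

```rust
// Workspace objects pair with ActionType::Role, collab objects with ActionType::Level;
// mixing them is rejected by validate_obj_action. Paths below are assumed from the diff.
use crate::biz::casbin::access_control::{AccessControl, Action, ActionType, ObjectType};
use app_error::AppError;
use database_entity::dto::{AFAccessLevel, AFRole};

async fn grant_and_check(
    ac: &AccessControl,
    uid: i64,
    workspace_id: &str,
    oid: &str,
) -> Result<bool, AppError> {
    // Role policy on the workspace object.
    ac.update_policy(
        &uid,
        &ObjectType::Workspace(workspace_id),
        &ActionType::Role(AFRole::Member),
    )
    .await?;

    // Access-level policy on the collab object.
    ac.update_policy(
        &uid,
        &ObjectType::Collab(oid),
        &ActionType::Level(AFAccessLevel::ReadAndWrite),
    )
    .await?;

    // Enforcement accepts anything implementing ToACAction, e.g. a plain Action.
    ac.enforce(&uid, &ObjectType::Collab(oid), Action::Write).await
}
```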

View file

@ -1,87 +1,52 @@
use crate::biz::casbin::enforcer::{AFEnforcerCache, ActionCacheKey, PolicyCacheKey};
use crate::biz::casbin::enforcer::{AFEnforcerCache, PolicyCacheKey};
use crate::state::RedisClient;
use redis::AsyncCommands;
use anyhow::anyhow;
use app_error::AppError;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::error;
/// Expire time for cache in seconds. When the cache is expired, the enforcer will re-evaluate the policy.
const EXPIRE_TIME: u64 = 60 * 60 * 3;
const EXPIRE_TIME: u64 = 60 * 60 * 24 * 3;
#[derive(Clone)]
pub struct AFEnforcerCacheImpl {
redis_client: Arc<Mutex<RedisClient>>,
redis_client: RedisClient,
}
impl AFEnforcerCacheImpl {
pub fn new(redis_client: RedisClient) -> Self {
Self {
redis_client: Arc::new(Mutex::new(redis_client)),
}
Self { redis_client }
}
}
#[async_trait]
impl AFEnforcerCache for AFEnforcerCacheImpl {
async fn set_enforcer_result(&self, key: &PolicyCacheKey, value: bool) -> Result<(), AppError> {
async fn set_enforcer_result(
&mut self,
key: &PolicyCacheKey,
value: bool,
) -> Result<(), AppError> {
self
.redis_client
.lock()
.await
.set_ex::<&str, bool, ()>(key, value, EXPIRE_TIME)
.await
.map_err(|e| AppError::Internal(anyhow!("Failed to set enforcer result in redis: {}", e)))
}
async fn get_enforcer_result(&self, key: &PolicyCacheKey) -> Option<bool> {
async fn get_enforcer_result(&mut self, key: &PolicyCacheKey) -> Option<bool> {
self
.redis_client
.lock()
.await
.get::<&str, Option<bool>>(key.as_ref())
.await
.ok()?
}
async fn remove_enforcer_result(&self, key: &PolicyCacheKey) {
if let Err(err) = self
.redis_client
.lock()
.await
.del::<&str, ()>(key.as_ref())
.await
{
async fn remove_enforcer_result(&mut self, key: &PolicyCacheKey) {
if let Err(err) = self.redis_client.del::<&str, ()>(key.as_ref()).await {
error!("Failed to remove enforcer result from redis: {}", err);
}
}
async fn set_action(&self, key: &ActionCacheKey, value: String) -> Result<(), AppError> {
self
.redis_client
.lock()
.await
.set_ex::<&str, String, ()>(key, value, EXPIRE_TIME)
.await
.map_err(|e| AppError::Internal(anyhow!("Failed to set action in redis: {}", e)))
}
async fn get_action(&self, key: &ActionCacheKey) -> Option<String> {
self
.redis_client
.lock()
.await
.get::<&str, Option<String>>(key.as_ref())
.await
.ok()?
}
async fn remove_action(&self, key: &ActionCacheKey) {
if let Err(err) = self.redis_client.lock().await.del::<&str, ()>(key).await {
error!("Failed to remove action from cache: {}", err);
}
}
}
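Because the enforcer only needs the three methods above, a process-local cache is enough when Redis is unavailable (tests, local runs). A sketch assuming the trait and `PolicyCacheKey` stay importable from `crate::biz::casbin::enforcer`:

```rust
// A HashMap-backed AFEnforcerCache, e.g. for tests without Redis.
// Import paths are assumptions based on this diff.
use crate::biz::casbin::enforcer::{AFEnforcerCache, PolicyCacheKey};
use app_error::AppError;
use async_trait::async_trait;
use std::collections::HashMap;

#[derive(Default)]
pub struct InMemoryEnforcerCache {
    results: HashMap<String, bool>,
}

#[async_trait]
impl AFEnforcerCache for InMemoryEnforcerCache {
    async fn set_enforcer_result(&mut self, key: &PolicyCacheKey, value: bool) -> Result<(), AppError> {
        self.results.insert(key.as_ref().to_string(), value);
        Ok(())
    }

    async fn get_enforcer_result(&mut self, key: &PolicyCacheKey) -> Option<bool> {
        self.results.get(key.as_ref()).copied()
    }

    async fn remove_enforcer_result(&mut self, key: &PolicyCacheKey) {
        self.results.remove(key.as_ref());
    }
}
```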

View file

@ -8,7 +8,7 @@ pub mod pg_listen;
mod workspace_ac;
pub use collab_ac::CollabAccessControlImpl;
pub use collab_ac::RealtimeCollabAccessControlImpl;
pub use enforcer::AFEnforcerCache;
pub use enforcer::ActionCacheKey;
pub use enforcer::PolicyCacheKey;
pub use workspace_ac::WorkspaceAccessControlImpl;

View file

@ -25,7 +25,7 @@ pub(crate) fn spawn_listen_on_collab_member_change(
let permission_row = select_permission(&pg_pool, &member_row.permission_id).await;
if let Ok(Some(row)) = permission_row {
if let Err(err) = enforcer
.update(
.update_policy(
&member_row.uid,
&ObjectType::Collab(&member_row.oid),
&ActionType::Level(row.access_level),
@ -44,7 +44,7 @@ pub(crate) fn spawn_listen_on_collab_member_change(
},
CollabMemberAction::DELETE => {
if let (Some(oid), Some(uid)) = (change.old_oid(), change.old_uid()) {
if let Err(err) = enforcer.remove(uid, &ObjectType::Collab(oid)).await {
if let Err(err) = enforcer.remove_policy(uid, &ObjectType::Collab(oid)).await {
warn!(
"Failed to remove the user:{} collab{} access control, error: {}",
uid, oid, err
@ -72,7 +72,7 @@ pub(crate) fn spawn_listen_on_workspace_member_change(
},
Some(member_row) => {
if let Err(err) = enforcer
.update(
.update_policy(
&member_row.uid,
&ObjectType::Workspace(&member_row.workspace_id.to_string()),
&ActionType::Role(AFRole::from(member_row.role_id as i32)),
@ -90,7 +90,7 @@ pub(crate) fn spawn_listen_on_workspace_member_change(
None => warn!("The workspace member change can't be None when the action is DELETE"),
Some(member_row) => {
if let Err(err) = enforcer
.remove(
.remove_policy(
&member_row.uid,
&ObjectType::Workspace(&member_row.workspace_id.to_string()),
)

View file

@ -1,12 +1,9 @@
use crate::biz::casbin::access_control::AccessControl;
use crate::biz::casbin::access_control::{AccessControl, Action};
use crate::biz::casbin::access_control::{ActionType, ObjectType};
use crate::biz::workspace::access_control::WorkspaceAccessControl;
use app_error::AppError;
use async_trait::async_trait;
use database_entity::dto::AFRole;
use sqlx::{Executor, Postgres};
use database_entity::dto::{AFAccessLevel, AFRole};
use tracing::instrument;
use uuid::Uuid;
@ -23,51 +20,67 @@ impl WorkspaceAccessControlImpl {
#[async_trait]
impl WorkspaceAccessControl for WorkspaceAccessControlImpl {
async fn get_workspace_role<'a, E>(
async fn enforce_role(
&self,
uid: &i64,
workspace_id: &Uuid,
_executor: E,
) -> Result<AFRole, AppError>
where
E: Executor<'a, Database = Postgres>,
{
let workspace_id = workspace_id.to_string();
workspace_id: &str,
role: AFRole,
) -> Result<bool, AppError> {
self
.access_control
.get_role(uid, &workspace_id)
.enforce(uid, &ObjectType::Workspace(workspace_id), role)
.await
}
async fn enforce_action(
&self,
uid: &i64,
workspace_id: &str,
action: Action,
) -> Result<bool, AppError> {
self
.access_control
.enforce(uid, &ObjectType::Workspace(workspace_id), action)
.await
.ok_or_else(|| {
AppError::RecordNotFound(format!(
"can't find the role for user:{} workspace:{}",
uid, workspace_id
))
})
}
#[instrument(level = "info", skip_all)]
async fn insert_workspace_role(
async fn insert_role(
&self,
uid: &i64,
workspace_id: &Uuid,
role: AFRole,
) -> Result<(), AppError> {
let _ = self
let access_level = AFAccessLevel::from(&role);
self
.access_control
.update(
.update_policy(
uid,
&ObjectType::Workspace(&workspace_id.to_string()),
&ActionType::Role(role),
)
.await?;
self
.access_control
.update_policy(
uid,
&ObjectType::Collab(&workspace_id.to_string()),
&ActionType::Level(access_level),
)
.await?;
Ok(())
}
#[instrument(level = "info", skip_all)]
async fn remove_role(&self, uid: &i64, workspace_id: &Uuid) -> Result<(), AppError> {
let _ = self
self
.access_control
.remove(uid, &ObjectType::Workspace(&workspace_id.to_string()))
.remove_policy(uid, &ObjectType::Workspace(&workspace_id.to_string()))
.await?;
self
.access_control
.remove_policy(uid, &ObjectType::Collab(&workspace_id.to_string()))
.await?;
Ok(())
}

View file

@ -1,23 +1,95 @@
use crate::api::workspace::COLLAB_PATTERN;
use crate::biz::casbin::access_control::Action;
use crate::biz::workspace::access_control::WorkspaceAccessControl;
use crate::middleware::access_control_mw::{AccessResource, HttpAccessControlService};
use actix_router::{Path, Url};
use crate::middleware::access_control_mw::{AccessResource, MiddlewareAccessControl};
use actix_router::{Path, ResourceDef, Url};
use actix_web::http::Method;
use app_error::AppError;
use async_trait::async_trait;
use database::collab::CollabStorageAccessControl;
use database_entity::dto::{AFAccessLevel, AFRole};
use realtime::collaborate::CollabAccessControl;
use sqlx::{Executor, Postgres};
use std::sync::Arc;
use tracing::{error, instrument};
use uuid::Uuid;
#[derive(Clone)]
pub struct CollabHttpAccessControl<AC: CollabAccessControl>(pub Arc<AC>);
use crate::biz::collab::cache::CollabCache;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{instrument, trace};
#[async_trait]
impl<AC> HttpAccessControlService for CollabHttpAccessControl<AC>
pub trait CollabAccessControl: Sync + Send + 'static {
async fn enforce_action(&self, uid: &i64, oid: &str, action: Action) -> Result<bool, AppError>;
async fn enforce_access_level(
&self,
uid: &i64,
oid: &str,
access_level: AFAccessLevel,
) -> Result<bool, AppError>;
/// Update the access level policy of the user for the collab
async fn update_access_level_policy(
&self,
uid: &i64,
oid: &str,
level: AFAccessLevel,
) -> Result<(), AppError>;
async fn remove_access_level(&self, uid: &i64, oid: &str) -> Result<(), AppError>;
}
#[derive(Clone)]
pub struct CollabMiddlewareAccessControl<AC: CollabAccessControl> {
pub access_control: Arc<AC>,
collab_cache: CollabCache,
skip_resources: Vec<(Method, ResourceDef)>,
require_access_levels: Vec<(ResourceDef, HashMap<Method, AFAccessLevel>)>,
}
impl<AC> CollabMiddlewareAccessControl<AC>
where
AC: CollabAccessControl,
{
pub fn new(access_control: Arc<AC>, collab_cache: CollabCache) -> Self {
Self {
skip_resources: vec![
// Skip access control when trying to create a collab
(Method::POST, ResourceDef::new(COLLAB_PATTERN)),
],
require_access_levels: vec![(
ResourceDef::new(COLLAB_PATTERN),
[
// Only the user with FullAccess can delete the collab
(Method::DELETE, AFAccessLevel::FullAccess),
]
.into(),
)],
access_control,
collab_cache,
}
}
fn should_skip(&self, method: &Method, path: &Path<Url>) -> bool {
self.skip_resources.iter().any(|(m, r)| {
if m != method {
return false;
}
r.is_match(path.as_str())
})
}
fn require_access_level(&self, method: &Method, path: &Path<Url>) -> Option<AFAccessLevel> {
self.require_access_levels.iter().find_map(|(r, roles)| {
if r.is_match(path.as_str()) {
roles.get(method).cloned()
} else {
None
}
})
}
}
#[async_trait]
impl<AC> MiddlewareAccessControl for CollabMiddlewareAccessControl<AC>
where
AC: CollabAccessControl,
{
@ -25,31 +97,54 @@ where
AccessResource::Collab
}
async fn check_workspace_permission(
#[instrument(name = "check_collab_permission", level = "trace", skip_all, err)]
async fn check_resource_permission(
&self,
_workspace_id: &Uuid,
_uid: &i64,
_method: Method,
) -> Result<(), AppError> {
error!("Shouldn't call CollabHttpAccessControl here");
Ok(())
}
#[instrument(level = "debug", skip_all, err)]
async fn check_collab_permission(
&self,
oid: &str,
uid: &i64,
oid: &str,
method: Method,
_path: &Path<Url>,
path: &Path<Url>,
) -> Result<(), AppError> {
if self.0.can_access_http_method(uid, oid, &method).await? {
if self.should_skip(&method, path) {
trace!("Skip access control for the request");
return Ok(());
}
let collab_exists = self.collab_cache.is_exist(oid).await?;
if !collab_exists {
return Err(AppError::RecordNotFound(format!(
"Collab not exist in db. {}",
oid
)));
}
let access_level = self.require_access_level(&method, path);
let result = match access_level {
None => {
self
.access_control
.enforce_action(uid, oid, Action::from(&method))
.await?
},
Some(access_level) => {
self
.access_control
.enforce_access_level(uid, oid, access_level)
.await?
},
};
if result {
Ok(())
} else {
Err(AppError::NotEnoughPermissions(format!(
"Not enough permissions to access the collab: {} with http method: {}",
oid, method
)))
Err(AppError::NotEnoughPermissions {
user: uid.to_string(),
action: format!(
"access collab:{} with url:{}, method:{}",
oid,
path.as_str(),
method
),
})
}
}
}
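The skip and require tables above are plain `(Method, ResourceDef)` lookups. A standalone sketch of that matching; the route pattern is a made-up stand-in for `COLLAB_PATTERN`:

```rust
// Match a request against a skip rule the way should_skip does above.
use actix_router::ResourceDef;
use actix_web::http::Method;

fn main() {
    // Hypothetical pattern; the real one is COLLAB_PATTERN from api::workspace.
    let collab = ResourceDef::new("/api/workspace/{workspace_id}/collab/{object_id}");
    let skip_resources = vec![(Method::POST, collab)];

    let should_skip = |method: &Method, path: &str| {
        skip_resources
            .iter()
            .any(|(m, r)| m == method && r.is_match(path))
    };

    let path = "/api/workspace/ws-1/collab/doc-1";
    assert!(should_skip(&Method::POST, path)); // creation bypasses the collab check
    assert!(!should_skip(&Method::GET, path)); // reads still go through enforcement
}
```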
@ -58,6 +153,7 @@ where
pub struct CollabStorageAccessControlImpl<CollabAC, WorkspaceAC> {
pub(crate) collab_access_control: Arc<CollabAC>,
pub(crate) workspace_access_control: Arc<WorkspaceAC>,
pub(crate) cache: CollabCache,
}
#[async_trait]
@ -67,37 +163,7 @@ where
CollabAC: CollabAccessControl,
WorkspaceAC: WorkspaceAccessControl,
{
async fn get_or_refresh_collab_access_level<'a, E: Executor<'a, Database = Postgres>>(
&self,
uid: &i64,
oid: &str,
executor: E,
) -> Result<AFAccessLevel, AppError> {
let access_level_result = self
.collab_access_control
.get_collab_access_level(uid, oid)
.await;
if let Ok(level) = access_level_result {
return Ok(level);
}
// Safe unwrap, we know it's an Err here
let err = access_level_result.unwrap_err();
if err.is_record_not_found() {
let member = database::collab::select_collab_member(uid, oid, executor).await?;
self
.collab_access_control
.insert_collab_access_level(uid, oid, member.permission.access_level)
.await?;
Ok(member.permission.access_level)
} else {
Err(err)
}
}
async fn cache_collab_access_level(
async fn update_policy(
&self,
uid: &i64,
oid: &str,
@ -105,19 +171,49 @@ where
) -> Result<(), AppError> {
self
.collab_access_control
.insert_collab_access_level(uid, oid, level)
.update_access_level_policy(uid, oid, level)
.await
}
async fn get_user_workspace_role<'a, E: Executor<'a, Database = Postgres>>(
&self,
uid: &i64,
workspace_id: &str,
executor: E,
) -> Result<AFRole, AppError> {
async fn enforce_read_collab(&self, uid: &i64, oid: &str) -> Result<bool, AppError> {
let collab_exists = self.cache.is_exist(oid).await?;
if !collab_exists {
return Err(AppError::RecordNotFound(format!(
"Collab not exist in db. {}",
oid
)));
}
self
.collab_access_control
.enforce_action(uid, oid, Action::Read)
.await
}
async fn enforce_write_collab(&self, uid: &i64, oid: &str) -> Result<bool, AppError> {
let collab_exists = self.cache.is_exist(oid).await?;
if !collab_exists {
return Err(AppError::RecordNotFound(format!(
"Collab not exist in db. {}",
oid
)));
}
self
.collab_access_control
.enforce_action(uid, oid, Action::Write)
.await
}
async fn enforce_delete(&self, uid: &i64, oid: &str) -> Result<bool, AppError> {
self
.collab_access_control
.enforce_access_level(uid, oid, AFAccessLevel::FullAccess)
.await
}
async fn enforce_write_workspace(&self, uid: &i64, workspace_id: &str) -> Result<bool, AppError> {
self
.workspace_access_control
.get_workspace_role(uid, &workspace_id.parse()?, executor)
.enforce_role(uid, workspace_id, AFRole::Owner)
.await
}
}

160
src/biz/collab/cache.rs Normal file
View file

@ -0,0 +1,160 @@
use crate::biz::collab::disk_cache::CollabDiskCache;
use crate::biz::collab::mem_cache::CollabMemCache;
use crate::biz::collab::storage::check_encoded_collab_data;
use app_error::AppError;
use collab::core::collab_plugin::EncodedCollab;
use crate::state::RedisClient;
use database_entity::dto::{CollabParams, QueryCollab, QueryCollabParams, QueryCollabResult};
use futures_util::{stream, StreamExt};
use itertools::{Either, Itertools};
use sqlx::{PgPool, Transaction};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tracing::{event, Level};
#[derive(Clone)]
pub struct CollabCache {
disk_cache: CollabDiskCache,
mem_cache: CollabMemCache,
hits: Arc<AtomicU64>,
total_attempts: Arc<AtomicU64>,
}
impl CollabCache {
pub fn new(redis_client: RedisClient, pg_pool: PgPool) -> Self {
let mem_cache = CollabMemCache::new(redis_client.clone());
let disk_cache = CollabDiskCache::new(pg_pool.clone());
Self {
disk_cache,
mem_cache,
hits: Arc::new(AtomicU64::new(0)),
total_attempts: Arc::new(AtomicU64::new(0)),
}
}
pub async fn get_collab_encoded(
&self,
uid: &i64,
params: QueryCollabParams,
) -> Result<EncodedCollab, AppError> {
self.total_attempts.fetch_add(1, Ordering::Relaxed);
// Attempt to retrieve encoded collab from memory cache, falling back to disk cache if necessary.
if let Some(encoded_collab) = self
.mem_cache
.get_encode_collab(&params.inner.object_id)
.await
{
event!(
Level::DEBUG,
"Get encoded collab:{} from cache",
params.object_id
);
self.hits.fetch_add(1, Ordering::Relaxed);
return Ok(encoded_collab);
}
// Retrieve from disk cache as fallback. After retrieval, the value is inserted into the memory cache.
let object_id = params.object_id.clone();
let encoded_collab = self.disk_cache.get_collab_encoded(uid, params).await?;
self
.mem_cache
.insert_encode_collab(object_id, &encoded_collab)
.await;
Ok(encoded_collab)
}
pub async fn batch_get_encode_collab(
&self,
uid: &i64,
queries: Vec<QueryCollab>,
) -> HashMap<String, QueryCollabResult> {
let mut results = HashMap::new();
// 1. Processes valid queries against the in-memory cache to retrieve cached values.
// - Queries not found in the cache are earmarked for disk retrieval.
let (disk_queries, values_from_mem_cache): (Vec<_>, HashMap<_, _>) = stream::iter(queries)
.then(|params| async move {
match self
.mem_cache
.get_encode_collab_bytes(&params.object_id)
.await
{
None => Either::Left(params),
Some(data) => Either::Right((
params.object_id.clone(),
QueryCollabResult::Success {
encode_collab_v1: data,
},
)),
}
})
.collect::<Vec<_>>()
.await
.into_iter()
.partition_map(|either| either);
results.extend(values_from_mem_cache);
// 2. Retrieves remaining values from the disk cache for queries not satisfied by the memory cache.
// - These values are then merged into the final result set.
let values_from_disk_cache = self.disk_cache.batch_get_collab(uid, disk_queries).await;
results.extend(values_from_disk_cache);
results
}
pub async fn insert_collab_encoded(
&self,
workspace_id: &str,
uid: &i64,
params: CollabParams,
transaction: &mut Transaction<'_, sqlx::Postgres>,
) -> Result<(), AppError> {
if let Err(err) = check_encoded_collab_data(&params.object_id, &params.encoded_collab_v1) {
let msg = format!(
"Can not decode the data into collab:{}, {}",
params.object_id, err
);
return Err(AppError::InvalidRequest(msg));
}
let object_id = params.object_id.clone();
let encoded_collab = params.encoded_collab_v1.clone();
self
.disk_cache
.upsert_collab_with_transaction(workspace_id, uid, params, transaction)
.await?;
self
.mem_cache
.insert_encode_collab_bytes(object_id, encoded_collab)
.await;
Ok(())
}
pub fn get_hit_rate(&self) -> f64 {
let hits = self.hits.load(Ordering::Relaxed) as f64;
let total_attempts = self.total_attempts.load(Ordering::Relaxed) as f64;
if total_attempts == 0.0 {
0.0
} else {
hits / total_attempts
}
}
pub async fn remove_collab(&self, object_id: &str) -> Result<(), AppError> {
self.mem_cache.remove_encode_collab(object_id).await?;
self.disk_cache.delete_collab(object_id).await?;
Ok(())
}
pub async fn is_exist(&self, oid: &str) -> Result<bool, AppError> {
let is_exist = self.disk_cache.is_exist(oid).await?;
Ok(is_exist)
}
pub fn pg_pool(&self) -> &sqlx::PgPool {
&self.disk_cache.pg_pool
}
}
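The read path above is a read-through cache: count the attempt, try the memory tier, fall back to the slower tier, backfill, and report hits over attempts as the hit rate. A standalone sketch of the same flow with HashMaps standing in for Redis and Postgres:

```rust
use std::collections::HashMap;

// HashMaps stand in for the Redis-backed mem cache and the Postgres-backed disk cache.
struct TwoTierCache {
    mem: HashMap<String, Vec<u8>>,
    disk: HashMap<String, Vec<u8>>,
    hits: u64,
    attempts: u64,
}

impl TwoTierCache {
    fn get(&mut self, oid: &str) -> Option<Vec<u8>> {
        self.attempts += 1;
        if let Some(v) = self.mem.get(oid) {
            self.hits += 1;
            return Some(v.clone());
        }
        // Fall back to the slower tier, then backfill the fast tier.
        let v = self.disk.get(oid)?.clone();
        self.mem.insert(oid.to_string(), v.clone());
        Some(v)
    }

    fn hit_rate(&self) -> f64 {
        if self.attempts == 0 {
            0.0
        } else {
            self.hits as f64 / self.attempts as f64
        }
    }
}

fn main() {
    let mut cache = TwoTierCache {
        mem: HashMap::new(),
        disk: HashMap::from([("doc-1".to_string(), vec![1, 2, 3])]),
        hits: 0,
        attempts: 0,
    };
    assert!(cache.get("doc-1").is_some()); // miss: served from disk, then backfilled
    assert!(cache.get("doc-1").is_some()); // hit: served from mem
    assert_eq!(cache.hit_rate(), 0.5);
}
```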

View file

@ -0,0 +1,100 @@
use anyhow::anyhow;
use app_error::AppError;
use collab::core::collab_plugin::EncodedCollab;
use database::collab::{
batch_select_collab_blob, delete_collab, insert_into_af_collab, is_collab_exists,
select_blob_from_af_collab, DatabaseResult,
};
use database_entity::dto::{CollabParams, QueryCollab, QueryCollabParams, QueryCollabResult};
use sqlx::{PgPool, Transaction};
use std::collections::HashMap;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{event, Level};
#[derive(Clone)]
pub struct CollabDiskCache {
pub pg_pool: PgPool,
}
impl CollabDiskCache {
pub fn new(pg_pool: PgPool) -> Self {
Self { pg_pool }
}
pub async fn is_exist(&self, object_id: &str) -> DatabaseResult<bool> {
let is_exist = is_collab_exists(object_id, &self.pg_pool).await?;
Ok(is_exist)
}
pub async fn upsert_collab_with_transaction(
&self,
workspace_id: &str,
uid: &i64,
params: CollabParams,
transaction: &mut Transaction<'_, sqlx::Postgres>,
) -> DatabaseResult<()> {
insert_into_af_collab(transaction, uid, workspace_id, &params).await?;
Ok(())
}
pub async fn get_collab_encoded(
&self,
_uid: &i64,
params: QueryCollabParams,
) -> Result<EncodedCollab, AppError> {
event!(
Level::INFO,
"Get encoded collab:{} from disk",
params.object_id
);
const MAX_ATTEMPTS: usize = 3;
let mut attempts = 0;
loop {
let result =
select_blob_from_af_collab(&self.pg_pool, &params.collab_type, &params.object_id).await;
match result {
Ok(data) => {
return tokio::task::spawn_blocking(move || {
EncodedCollab::decode_from_bytes(&data).map_err(|err| {
AppError::Internal(anyhow!("fail to decode data to EncodedCollab: {:?}", err))
})
})
.await?;
},
Err(e) => {
// Handle non-retryable errors immediately
if matches!(e, sqlx::Error::RowNotFound) {
let msg = format!("Can't find the row for query: {:?}", params);
return Err(AppError::RecordNotFound(msg));
}
// Increment attempts and retry if below MAX_ATTEMPTS and the error is retryable
if attempts < MAX_ATTEMPTS - 1 && matches!(e, sqlx::Error::PoolTimedOut) {
attempts += 1;
sleep(Duration::from_millis(500 * attempts as u64)).await;
continue;
} else {
return Err(e.into());
}
},
}
}
}
pub async fn batch_get_collab(
&self,
_uid: &i64,
queries: Vec<QueryCollab>,
) -> HashMap<String, QueryCollabResult> {
batch_select_collab_blob(&self.pg_pool, queries).await
}
pub async fn delete_collab(&self, object_id: &str) -> DatabaseResult<()> {
delete_collab(&self.pg_pool, object_id).await?;
Ok(())
}
}
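The loop above only retries the error worth retrying. The same policy pulled out as a standalone sketch, with a stand-in error enum instead of the sqlx error type:

```rust
// Bounded retry with a linearly growing delay, mirroring MAX_ATTEMPTS = 3 and
// the 500ms * attempt backoff above; every non-retryable error surfaces immediately.
use std::time::Duration;
use tokio::time::sleep;

#[derive(Debug)]
enum FetchError {
    PoolTimedOut, // retryable
    RowNotFound,  // not retryable
}

async fn fetch_with_retry<F, Fut>(mut fetch: F) -> Result<Vec<u8>, FetchError>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<Vec<u8>, FetchError>>,
{
    const MAX_ATTEMPTS: usize = 3;
    let mut attempts = 0;
    loop {
        match fetch().await {
            Ok(data) => return Ok(data),
            Err(FetchError::PoolTimedOut) if attempts < MAX_ATTEMPTS - 1 => {
                attempts += 1;
                sleep(Duration::from_millis(500 * attempts as u64)).await;
            },
            Err(e) => return Err(e),
        }
    }
}

#[tokio::main]
async fn main() {
    let mut calls = 0;
    let result = fetch_with_retry(|| {
        calls += 1;
        async move {
            if calls < 3 {
                Err(FetchError::PoolTimedOut) // first two attempts time out
            } else {
                Ok(vec![42])
            }
        }
    })
    .await;
    assert_eq!(result.unwrap(), vec![42]);
}
```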

View file

@ -1,7 +1,9 @@
use crate::state::RedisClient;
use collab::core::collab_plugin::EncodedCollab;
use redis::AsyncCommands;
use std::sync::atomic::{AtomicU64, Ordering};
use anyhow::anyhow;
use app_error::AppError;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{error, trace};
@ -9,33 +11,31 @@ use tracing::{error, trace};
#[derive(Clone)]
pub struct CollabMemCache {
redis_client: Arc<Mutex<RedisClient>>,
hits: Arc<AtomicU64>,
total_attempts: Arc<AtomicU64>,
}
impl CollabMemCache {
pub fn new(redis_client: RedisClient) -> Self {
Self {
redis_client: Arc::new(Mutex::new(redis_client)),
hits: Arc::new(AtomicU64::new(0)),
total_attempts: Arc::new(AtomicU64::new(0)),
}
}
pub async fn remove_encode_collab(&self, object_id: &str) {
if let Err(err) = self
pub async fn remove_encode_collab(&self, object_id: &str) -> Result<(), AppError> {
self
.redis_client
.lock()
.await
.del::<&str, ()>(object_id)
.await
{
error!("Failed to remove encoded collab from redis: {:?}", err);
}
.map_err(|err| {
AppError::Internal(anyhow!(
"Failed to remove encoded collab from redis: {:?}",
err
))
})
}
pub async fn get_encode_collab_bytes(&self, object_id: &str) -> Option<Vec<u8>> {
self.total_attempts.fetch_add(1, Ordering::Relaxed);
let result = self
.redis_client
.lock()
@ -43,10 +43,7 @@ impl CollabMemCache {
.get::<_, Option<Vec<u8>>>(object_id)
.await;
match result {
Ok(bytes) => {
self.hits.fetch_add(1, Ordering::Relaxed);
bytes
},
Ok(bytes) => bytes,
Err(err) => {
error!("Failed to get encoded collab from redis: {:?}", err);
None
@ -101,15 +98,4 @@ impl CollabMemCache {
.set_ex::<_, Vec<u8>, ()>(object_id, bytes, 259200)
.await
}
pub fn get_hit_rate(&self) -> f64 {
let hits = self.hits.load(Ordering::Relaxed) as f64;
let total_attempts = self.total_attempts.load(Ordering::Relaxed) as f64;
if total_attempts == 0.0 {
0.0
} else {
hits / total_attempts
}
}
}

View file

@ -1,5 +1,7 @@
pub mod access_control;
mod mem_cache;
pub mod cache;
pub mod disk_cache;
pub mod mem_cache;
pub mod metrics;
pub mod ops;
pub mod storage;

View file

@ -8,6 +8,8 @@ use database_entity::dto::{
QueryCollabMembers, UpdateCollabMemberParams,
};
use crate::biz::collab::access_control::CollabAccessControl;
use sqlx::{types::Uuid, PgPool};
use tracing::{event, trace};
use validator::Validate;
@ -28,6 +30,7 @@ pub async fn delete_collab(
pub async fn create_collab_member(
pg_pool: &PgPool,
params: &InsertCollabMemberParams,
collab_access_control: &impl CollabAccessControl,
) -> Result<(), AppError> {
params.validate()?;
@ -65,6 +68,10 @@ pub async fn create_collab_member(
)
.await?;
collab_access_control
.update_access_level_policy(&params.uid, &params.object_id, params.access_level)
.await?;
transaction
.commit()
.await
@ -76,6 +83,7 @@ pub async fn upsert_collab_member(
pg_pool: &PgPool,
_user_uuid: &Uuid,
params: &UpdateCollabMemberParams,
collab_access_control: &impl CollabAccessControl,
) -> Result<(), AppError> {
params.validate()?;
let mut transaction = pg_pool
@ -90,6 +98,10 @@ pub async fn upsert_collab_member(
)));
}
collab_access_control
.update_access_level_policy(&params.uid, &params.object_id, params.access_level)
.await?;
database::collab::insert_collab_member(
params.uid,
&params.object_id,
@ -118,15 +130,29 @@ pub async fn get_collab_member(
pub async fn delete_collab_member(
pg_pool: &PgPool,
params: &CollabMemberIdentify,
collab_access_control: &impl CollabAccessControl,
) -> Result<(), AppError> {
params.validate()?;
let mut transaction = pg_pool
.begin()
.await
.context("acquire transaction to remove collab member")?;
event!(
tracing::Level::DEBUG,
"Deleting member:{} from {}",
params.uid,
params.object_id
);
database::collab::delete_collab_member(params.uid, &params.object_id, pg_pool).await?;
database::collab::delete_collab_member(params.uid, &params.object_id, &mut transaction).await?;
collab_access_control
.remove_access_level(&params.uid, &params.object_id)
.await?;
transaction
.commit()
.await
.context("fail to commit the transaction to remove collab member")?;
Ok(())
}
pub async fn get_collab_member_list(

View file

@ -1,7 +1,6 @@
use crate::biz::casbin::{CollabAccessControlImpl, WorkspaceAccessControlImpl};
use crate::biz::collab::access_control::CollabStorageAccessControlImpl;
use crate::biz::collab::mem_cache::CollabMemCache;
use crate::state::RedisClient;
use anyhow::Context;
use app_error::AppError;
use async_trait::async_trait;
@ -11,81 +10,53 @@ use collab::core::origin::CollabOrigin;
use collab::preclude::Collab;
use database::collab::{
is_collab_exists, CollabStorage, CollabStorageAccessControl, CollabStoragePgImpl, DatabaseResult,
WriteConfig,
is_collab_exists, CollabStorage, CollabStorageAccessControl, DatabaseResult,
};
use database_entity::dto::{
AFAccessLevel, AFSnapshotMeta, AFSnapshotMetas, CollabParams, CreateCollabParams,
InsertSnapshotParams, QueryCollab, QueryCollabParams, QueryCollabResult, SnapshotData,
};
use futures::stream::{self, StreamExt};
use itertools::{Either, Itertools};
use sqlx::{PgPool, Transaction};
use sqlx::Transaction;
use std::collections::HashMap;
use std::ops::DerefMut;
use std::time::Duration;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::oneshot;
use tokio::time::timeout;
use crate::biz::collab::metrics::CollabMetrics;
use crate::biz::collab::cache::CollabCache;
use crate::biz::snapshot::SnapshotControl;
use realtime::collaborate::{RTCommand, RTCommandSender};
use tracing::{error, event, instrument, Level};
use tracing::{error, instrument};
use validator::Validate;
pub type CollabStorageImpl = CollabStoragePostgresImpl<
pub type CollabAccessControlStorage = CollabStorageImpl<
CollabStorageAccessControlImpl<CollabAccessControlImpl, WorkspaceAccessControlImpl>,
>;
pub async fn init_collab_storage(
pg_pool: PgPool,
redis_client: RedisClient,
collab_access_control: CollabAccessControlImpl,
workspace_access_control: WorkspaceAccessControlImpl,
collab_metrics: Arc<CollabMetrics>,
realtime_server_command_sender: RTCommandSender,
) -> CollabStorageImpl {
let access_control = CollabStorageAccessControlImpl {
collab_access_control: collab_access_control.into(),
workspace_access_control: workspace_access_control.into(),
};
let disk_cache = CollabStoragePgImpl::new(pg_pool.clone());
let mem_cache = CollabMemCache::new(redis_client.clone());
let snapshot_control = SnapshotControl::new(redis_client, pg_pool, collab_metrics).await;
CollabStoragePostgresImpl::new(
disk_cache,
mem_cache,
access_control,
snapshot_control,
realtime_server_command_sender,
)
}
/// A wrapper around the actual storage implementation that provides access control and caching.
#[derive(Clone)]
pub struct CollabStoragePostgresImpl<AC> {
disk_cache: CollabStoragePgImpl,
mem_cache: CollabMemCache,
pub struct CollabStorageImpl<AC> {
cache: CollabCache,
/// access control for collab object. Including read/write
access_control: AC,
snapshot_control: SnapshotControl,
rt_cmd: RTCommandSender,
}
impl<AC> CollabStoragePostgresImpl<AC>
impl<AC> CollabStorageImpl<AC>
where
AC: CollabStorageAccessControl,
{
pub fn new(
disk_cache: CollabStoragePgImpl,
mem_cache: CollabMemCache,
cache: CollabCache,
access_control: AC,
snapshot_control: SnapshotControl,
rt_cmd_sender: RTCommandSender,
) -> Self {
Self {
disk_cache,
mem_cache,
cache,
access_control,
snapshot_control,
rt_cmd: rt_cmd_sender,
@ -107,39 +78,34 @@ where
// If the collab already exists, check if the user has enough permissions to update collab
let can_write = self
.access_control
.get_or_refresh_collab_access_level(uid, &params.object_id, transaction.deref_mut())
.await
.context(format!(
"Can't find the access level when user:{} try to insert collab",
uid
))?
.can_write();
.enforce_write_collab(uid, &params.object_id)
.await?;
if !can_write {
return Err(AppError::NotEnoughPermissions(format!(
"user:{} doesn't have enough permissions to update collab {}",
uid, params.object_id
)));
return Err(AppError::NotEnoughPermissions {
user: uid.to_string(),
action: format!("update collab:{}", params.object_id),
});
}
} else {
// If the collab doesn't exist, check if the user has enough permissions to create collab.
// If the user is the owner or member of the workspace, the user can create collab.
let can_write_workspace = self
.access_control
.get_user_workspace_role(uid, workspace_id, transaction.deref_mut())
.await?
.can_create_collab();
.enforce_write_workspace(uid, workspace_id)
.await?;
if !can_write_workspace {
return Err(AppError::NotEnoughPermissions(format!(
"user:{} doesn't have enough permissions to insert collab {}",
uid, params.object_id
)));
return Err(AppError::NotEnoughPermissions {
user: uid.to_string(),
action: format!("write workspace:{}", workspace_id),
});
}
// Grant the creator full access once the workspace permission check passes.
self
.access_control
.cache_collab_access_level(uid, &params.object_id, AFAccessLevel::FullAccess)
.update_policy(uid, &params.object_id, AFAccessLevel::FullAccess)
.await?;
}
@ -181,29 +147,36 @@ where
}
#[async_trait]
impl<AC> CollabStorage for CollabStoragePostgresImpl<AC>
impl<AC> CollabStorage for CollabStorageImpl<AC>
where
AC: CollabStorageAccessControl,
{
fn config(&self) -> &WriteConfig {
self.disk_cache.config()
}
fn encode_collab_mem_hit_rate(&self) -> f64 {
self.mem_cache.get_hit_rate()
self.cache.get_hit_rate()
}
async fn upsert_collab(&self, uid: &i64, params: CreateCollabParams) -> DatabaseResult<()> {
async fn insert_collab(
&self,
uid: &i64,
params: CreateCollabParams,
is_new: bool,
) -> DatabaseResult<()> {
let mut transaction = self
.disk_cache
.pg_pool
.cache
.pg_pool()
.begin()
.await
.context("acquire transaction to upsert collab")
.map_err(AppError::from)?;
if is_new {
self
.access_control
.update_policy(uid, &params.object_id, AFAccessLevel::FullAccess)
.await?;
}
let (params, workspace_id) = params.split();
self
.upsert_collab_with_transaction(&workspace_id, uid, params, &mut transaction)
.insert_or_update_collab(&workspace_id, uid, params, &mut transaction)
.await?;
transaction
.commit()
@ -215,7 +188,7 @@ where
#[instrument(level = "trace", skip(self, params), oid = %params.oid, err)]
#[allow(clippy::blocks_in_if_conditions)]
async fn upsert_collab_with_transaction(
async fn insert_or_update_collab(
&self,
workspace_id: &str,
uid: &i64,
@ -227,26 +200,10 @@ where
.check_collab_permission(workspace_id, uid, &params, transaction)
.await?;
// Check if the data can be decoded into collab
if let Err(err) = check_encoded_collab_data(&params.object_id, &params.encoded_collab_v1) {
let msg = format!(
"Can not decode the data into collab:{}, {}",
params.object_id, err
);
return Err(AppError::InvalidRequest(msg));
}
let object_id = params.object_id.clone();
let encoded_collab = params.encoded_collab_v1.clone();
self
.disk_cache
.upsert_collab_with_transaction(workspace_id, uid, params, transaction)
.cache
.insert_collab_encoded(workspace_id, uid, params, transaction)
.await?;
self
.mem_cache
.insert_encode_collab_bytes(object_id, encoded_collab)
.await;
Ok(())
}
@ -257,11 +214,20 @@ where
is_collab_init: bool,
) -> DatabaseResult<EncodedCollab> {
params.validate()?;
self
// Check if the user has enough permissions to access the collab
let can_read = self
.access_control
.get_or_refresh_collab_access_level(uid, &params.object_id, &self.disk_cache.pg_pool)
.enforce_read_collab(uid, &params.object_id)
.await?;
if !can_read {
return Err(AppError::NotEnoughPermissions {
user: uid.to_string(),
action: format!("read collab:{}", params.object_id),
});
}
// If this request is not part of collab initialization, try the in-memory editing collab first and return early on success.
if !is_collab_init {
// Attempt to retrieve encoded collab from the editing collab
@ -270,28 +236,8 @@ where
}
}
// Attempt to retrieve encoded collab from memory cache, falling back to disk cache if necessary.
if let Some(encoded_collab) = self
.mem_cache
.get_encode_collab(&params.inner.object_id)
.await
{
event!(
Level::DEBUG,
"Get encoded collab:{} from cache",
params.object_id
);
return Ok(encoded_collab);
}
// Retrieve from disk cache as fallback. After retrieval, the value is inserted into the memory cache.
let object_id = params.object_id.clone();
let encoded_collab = self.disk_cache.get_collab_encoded(uid, params).await?;
self
.mem_cache
.insert_encode_collab(object_id, &encoded_collab)
.await;
Ok(encoded_collab)
let encode_collab = self.cache.get_collab_encoded(uid, params).await?;
Ok(encode_collab)
}
async fn batch_get_collab(
@ -299,7 +245,7 @@ where
uid: &i64,
queries: Vec<QueryCollab>,
) -> HashMap<String, QueryCollabResult> {
// 1. Partition queries based on validation into valid queries and errors (with associated error messages).
// Partition queries based on validation into valid queries and errors (with associated error messages).
let (valid_queries, mut results): (Vec<_>, HashMap<_, _>) =
queries
.into_iter()
@ -313,65 +259,27 @@ where
)),
});
// 2. Processes valid queries against the in-memory cache to retrieve cached values.
// - Queries not found in the cache are earmarked for disk retrieval.
let (disk_queries, values_from_mem_cache): (Vec<_>, HashMap<_, _>) =
stream::iter(valid_queries)
.then(|params| async move {
match self
.mem_cache
.get_encode_collab_bytes(&params.object_id)
.await
{
None => Either::Left(params),
Some(data) => Either::Right((
params.object_id.clone(),
QueryCollabResult::Success {
encode_collab_v1: data,
},
)),
}
})
.collect::<Vec<_>>()
.await
.into_iter()
.partition_map(|either| either);
results.extend(values_from_mem_cache);
// 3. Retrieves remaining values from the disk cache for queries not satisfied by the memory cache.
// - These values are then merged into the final result set.
let values_from_disk_cache = self.disk_cache.batch_get_collab(uid, disk_queries).await;
results.extend(values_from_disk_cache);
results.extend(self.cache.batch_get_encode_collab(uid, valid_queries).await);
results
}
async fn delete_collab(&self, uid: &i64, object_id: &str) -> DatabaseResult<()> {
if !self
.access_control
.get_or_refresh_collab_access_level(uid, object_id, &self.disk_cache.pg_pool)
.await
.context(format!(
"Can't find the access level when user:{} try to delete {}",
uid, object_id
))?
.can_delete()
{
return Err(AppError::NotEnoughPermissions(format!(
"user:{} doesn't have enough permissions to delete collab {}",
uid, object_id
)));
if !self.access_control.enforce_delete(uid, object_id).await? {
return Err(AppError::NotEnoughPermissions {
user: uid.to_string(),
action: format!("delete collab:{}", object_id),
});
}
self.mem_cache.remove_encode_collab(object_id).await;
self.disk_cache.delete_collab(uid, object_id).await
self.cache.remove_collab(object_id).await?;
Ok(())
}
async fn should_create_snapshot(&self, oid: &str) -> bool {
self.disk_cache.should_create_snapshot(oid).await
self.snapshot_control.should_create_snapshot(oid).await
}
async fn create_snapshot(&self, params: InsertSnapshotParams) -> DatabaseResult<AFSnapshotMeta> {
self.disk_cache.create_snapshot(params).await
self.snapshot_control.create_snapshot(params).await
}
async fn queue_snapshot(&self, params: InsertSnapshotParams) -> DatabaseResult<()> {
@ -384,18 +292,14 @@ where
object_id: &str,
snapshot_id: &i64,
) -> DatabaseResult<SnapshotData> {
match self
self
.snapshot_control
.get_snapshot(workspace_id, object_id)
.get_snapshot(workspace_id, object_id, snapshot_id)
.await
{
None => self.disk_cache.get_collab_snapshot(snapshot_id).await,
Some(data) => Ok(data),
}
}
async fn get_collab_snapshot_list(&self, oid: &str) -> DatabaseResult<AFSnapshotMetas> {
self.disk_cache.get_collab_snapshot_list(oid).await
self.snapshot_control.get_collab_snapshot_list(oid).await
}
}
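A simplified sketch of the write gate this storage layer applies before persisting a collab; the trait below is a local stand-in for `CollabStorageAccessControl`, with the `Result` plumbing dropped for brevity:

```rust
// Standalone sketch of the decision made before writing a collab.
use async_trait::async_trait;

#[async_trait]
trait WriteGate {
    async fn enforce_write_collab(&self, uid: i64, oid: &str) -> bool;
    async fn enforce_write_workspace(&self, uid: i64, workspace_id: &str) -> bool;
    async fn grant_full_access(&self, uid: i64, oid: &str);
}

async fn can_upsert(
    gate: &impl WriteGate,
    uid: i64,
    workspace_id: &str,
    oid: &str,
    exists: bool,
) -> bool {
    if exists {
        // Updating an existing collab requires write access on that collab.
        gate.enforce_write_collab(uid, oid).await
    } else if gate.enforce_write_workspace(uid, workspace_id).await {
        // Creating a new collab only requires write access on the workspace,
        // after which the creator is granted full access on the new object.
        gate.grant_full_access(uid, oid).await;
        true
    } else {
        false
    }
}
```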

View file

@ -4,14 +4,18 @@ use crate::biz::snapshot::queue::PendingQueue;
use crate::state::RedisClient;
use app_error::AppError;
use async_stream::stream;
use database::collab::{create_snapshot_and_maintain_limit, COLLAB_SNAPSHOT_LIMIT};
use database_entity::dto::{InsertSnapshotParams, SnapshotData};
use database::collab::{
create_snapshot_and_maintain_limit, get_all_collab_snapshot_meta, select_snapshot,
should_create_snapshot, DatabaseResult, COLLAB_SNAPSHOT_LIMIT,
};
use database_entity::dto::{AFSnapshotMeta, AFSnapshotMetas, InsertSnapshotParams, SnapshotData};
use futures_util::StreamExt;
use sqlx::PgPool;
use crate::biz::collab::storage::check_encoded_collab_data;
use anyhow::anyhow;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
@ -23,7 +27,7 @@ use validator::Validate;
pub type SnapshotCommandReceiver = tokio::sync::mpsc::Receiver<SnapshotCommand>;
pub type SnapshotCommandSender = tokio::sync::mpsc::Sender<SnapshotCommand>;
pub const SNAPSHOT_TICK_INTERVAL: Duration = Duration::from_secs(10);
pub const SNAPSHOT_TICK_INTERVAL: Duration = Duration::from_secs(2);
pub enum SnapshotCommand {
InsertSnapshot(InsertSnapshotParams),
@ -34,6 +38,7 @@ pub enum SnapshotCommand {
pub struct SnapshotControl {
cache: SnapshotCache,
command_sender: SnapshotCommandSender,
pg_pool: PgPool,
}
impl SnapshotControl {
@ -46,7 +51,7 @@ impl SnapshotControl {
let (command_sender, rx) = tokio::sync::mpsc::channel(2000);
let cache = SnapshotCache::new(redis_client);
let runner = SnapshotCommandRunner::new(pg_pool, cache.clone(), rx);
let runner = SnapshotCommandRunner::new(pg_pool.clone(), cache.clone(), rx);
tokio::spawn(runner.run());
let cloned_sender = command_sender.clone();
@ -71,9 +76,67 @@ impl SnapshotControl {
Self {
cache,
command_sender,
pg_pool,
}
}
pub async fn should_create_snapshot(&self, oid: &str) -> bool {
if oid.is_empty() {
warn!("unexpected empty object id when checking should_create_snapshot");
return false;
}
should_create_snapshot(oid, &self.pg_pool)
.await
.unwrap_or(false)
}
pub async fn create_snapshot(
&self,
params: InsertSnapshotParams,
) -> DatabaseResult<AFSnapshotMeta> {
params.validate()?;
debug!("create snapshot for object:{}", params.object_id);
match self.pg_pool.try_begin().await {
Ok(Some(transaction)) => {
let meta = create_snapshot_and_maintain_limit(
transaction,
&params.workspace_id,
&params.object_id,
&params.encoded_collab_v1,
COLLAB_SNAPSHOT_LIMIT,
)
.await?;
Ok(meta)
},
_ => Err(AppError::Internal(anyhow!(
"fail to acquire transaction to create snapshot for object:{}",
params.object_id,
))),
}
}
pub async fn get_collab_snapshot(&self, snapshot_id: &i64) -> DatabaseResult<SnapshotData> {
match select_snapshot(&self.pg_pool, snapshot_id).await? {
None => Err(AppError::RecordNotFound(format!(
"Can't find the snapshot with id:{}",
snapshot_id
))),
Some(row) => Ok(SnapshotData {
object_id: row.oid,
encoded_collab_v1: row.blob,
workspace_id: row.workspace_id.to_string(),
}),
}
}
/// Returns the list of snapshots for the given object_id in descending order of creation time.
pub async fn get_collab_snapshot_list(&self, oid: &str) -> DatabaseResult<AFSnapshotMetas> {
let metas = get_all_collab_snapshot_meta(&self.pg_pool, oid).await?;
Ok(metas)
}
pub async fn queue_snapshot(&self, params: InsertSnapshotParams) -> Result<(), AppError> {
params.validate()?;
trace!("Queuing snapshot for {}", params.object_id);
@ -85,14 +148,23 @@ impl SnapshotControl {
Ok(())
}
pub async fn get_snapshot(&self, workspace_id: &str, object_id: &str) -> Option<SnapshotData> {
pub async fn get_snapshot(
&self,
workspace_id: &str,
object_id: &str,
snapshot_id: &i64,
) -> Result<SnapshotData, AppError> {
let key = SnapshotKey::from_object_id(object_id);
let encoded_collab_v1 = self.cache.try_get(&key.0).await.ok()??;
Some(SnapshotData {
encoded_collab_v1,
workspace_id: workspace_id.to_string(),
object_id: object_id.to_string(),
})
let encoded_collab_v1 = self.cache.try_get(&key.0).await.unwrap_or(None);
match encoded_collab_v1 {
None => self.get_collab_snapshot(snapshot_id).await,
Some(encoded_collab_v1) => Ok(SnapshotData {
encoded_collab_v1,
workspace_id: workspace_id.to_string(),
object_id: object_id.to_string(),
}),
}
}
}
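Editor's note: the reworked get_snapshot reads the Redis-backed SnapshotCache first and only falls back to Postgres (via get_collab_snapshot and the caller-supplied snapshot_id) when the cached blob is absent. Below is a minimal, dependency-free sketch of that cache-then-database fallback; CacheStub and RowStub are hypothetical stand-ins for the real SnapshotCache and snapshot row types, the real lookups are async, and the real method returns a Result rather than an Option.

use std::collections::HashMap;

// Hypothetical stand-ins for the Redis-backed cache and the Postgres snapshot table.
struct CacheStub(HashMap<String, Vec<u8>>);
struct RowStub {
    oid: String,
    workspace_id: String,
    blob: Vec<u8>,
}

fn get_snapshot_stub(
    cache: &CacheStub,
    db: &HashMap<i64, RowStub>,
    workspace_id: &str,
    object_id: &str,
    snapshot_id: i64,
) -> Option<(String, String, Vec<u8>)> {
    // 1. Try the cache first; a miss (or a cache error) is treated as None.
    if let Some(blob) = cache.0.get(object_id) {
        return Some((workspace_id.to_string(), object_id.to_string(), blob.clone()));
    }
    // 2. Fall back to the snapshot table, keyed by snapshot_id.
    db.get(&snapshot_id)
        .map(|row| (row.workspace_id.clone(), row.oid.clone(), row.blob.clone()))
}

fn main() {
    let cache = CacheStub(HashMap::new());
    let mut db = HashMap::new();
    db.insert(
        1,
        RowStub {
            oid: "o1".into(),
            workspace_id: "w1".into(),
            blob: vec![1, 2, 3],
        },
    );
    // A cache miss falls through to the database row.
    assert!(get_snapshot_stub(&cache, &db, "w1", "o1", 1).is_some());
}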

View file

@ -21,7 +21,7 @@ use uuid::Uuid;
use workspace_template::document::get_started::GetStartedDocumentTemplate;
use workspace_template::{WorkspaceTemplate, WorkspaceTemplateBuilder};
use super::collab::storage::CollabStorageImpl;
use super::collab::storage::CollabAccessControlStorage;
/// Verify the token from the gotrue server and create the user if it is a new user
/// Return true if the user is a new user
@ -59,7 +59,7 @@ pub async fn verify_token(access_token: &str, state: &AppState) -> Result<bool,
// It's essential to cache the user's role because subsequent actions will rely on this cached information.
state
.workspace_access_control
.insert_workspace_role(
.insert_role(
&new_uid,
&Uuid::parse_str(&workspace_id).unwrap(),
AFRole::Owner,
@ -72,7 +72,7 @@ pub async fn verify_token(access_token: &str, state: &AppState) -> Result<bool,
&workspace_id,
&mut txn,
vec![GetStartedDocumentTemplate],
&state.collab_storage,
&state.collab_access_control_storage,
)
.await?;
}
@ -91,8 +91,7 @@ pub async fn initialize_workspace_for_user<T>(
workspace_id: &str,
txn: &mut Transaction<'_, sqlx::Postgres>,
templates: Vec<T>,
// state: &AppState,
collab_storage: &Arc<CollabStorageImpl>,
collab_storage: &Arc<CollabAccessControlStorage>,
) -> Result<(), AppError>
where
T: WorkspaceTemplate + Send + Sync + 'static,
@ -111,7 +110,7 @@ where
.map_err(|err| AppError::Internal(anyhow::Error::from(err)))?;
collab_storage
.upsert_collab_with_transaction(
.insert_or_update_collab(
workspace_id,
&uid,
CollabParams {

View file

@ -1,6 +1,6 @@
#![allow(unused)]
use crate::component::auth::jwt::UserUuid;
use crate::middleware::access_control_mw::{AccessResource, HttpAccessControlService};
use crate::middleware::access_control_mw::{AccessResource, MiddlewareAccessControl};
use actix_http::Method;
use async_trait::async_trait;
use database::user::select_uid_from_uuid;
@ -9,7 +9,10 @@ use sqlx::{Executor, PgPool, Postgres};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use actix_router::{Path, Url};
use crate::api::workspace::{WORKSPACE_MEMBER_PATTERN, WORKSPACE_PATTERN};
use crate::biz::casbin::access_control::Action;
use crate::state::UserCache;
use actix_router::{Path, ResourceDef, Url};
use anyhow::anyhow;
use app_error::AppError;
use database_entity::dto::AFRole;
@ -20,32 +23,86 @@ use uuid::Uuid;
#[async_trait]
pub trait WorkspaceAccessControl: Send + Sync + 'static {
async fn get_workspace_role<'a, E>(
async fn enforce_role(
&self,
uid: &i64,
workspace_id: &Uuid,
executor: E,
) -> Result<AFRole, AppError>
where
E: Executor<'a, Database = Postgres>;
async fn insert_workspace_role(
&self,
uid: &i64,
workspace_id: &Uuid,
workspace_id: &str,
role: AFRole,
) -> Result<(), AppError>;
) -> Result<bool, AppError>;
async fn enforce_action(
&self,
uid: &i64,
workspace_id: &str,
action: Action,
) -> Result<bool, AppError>;
async fn insert_role(&self, uid: &i64, workspace_id: &Uuid, role: AFRole)
-> Result<(), AppError>;
async fn remove_role(&self, uid: &i64, workspace_id: &Uuid) -> Result<(), AppError>;
}
#[derive(Clone)]
pub struct WorkspaceHttpAccessControl<AC: WorkspaceAccessControl> {
pub struct WorkspaceMiddlewareAccessControl<AC: WorkspaceAccessControl> {
pub pg_pool: PgPool,
pub access_control: Arc<AC>,
skip_resources: Vec<(Method, ResourceDef)>,
require_role_rules: Vec<(ResourceDef, HashMap<Method, AFRole>)>,
}
impl<AC> WorkspaceMiddlewareAccessControl<AC>
where
AC: WorkspaceAccessControl,
{
pub fn new(pg_pool: PgPool, access_control: Arc<AC>) -> Self {
Self {
pg_pool,
// Skip access control when the request matches the following resources
skip_resources: vec![
// Skip access control when the request is a POST and the path matches WORKSPACE_PATTERN,
(Method::POST, ResourceDef::new(WORKSPACE_PATTERN)),
],
// Require a specific role for the given resources
require_role_rules: vec![
// Only the Owner can manage the workspace members
(
ResourceDef::new(WORKSPACE_MEMBER_PATTERN),
[
(Method::POST, AFRole::Owner),
(Method::DELETE, AFRole::Owner),
(Method::PUT, AFRole::Owner),
(Method::GET, AFRole::Owner),
]
.into(),
),
],
access_control,
}
}
fn should_skip(&self, method: &Method, path: &Path<Url>) -> bool {
self.skip_resources.iter().any(|(m, r)| {
if m != method {
return false;
}
r.is_match(path.as_str())
})
}
fn require_role(&self, method: &Method, path: &Path<Url>) -> Option<AFRole> {
self.require_role_rules.iter().find_map(|(r, roles)| {
if r.is_match(path.as_str()) {
roles.get(method).cloned()
} else {
None
}
})
}
}
#[async_trait]
impl<AC> HttpAccessControlService for WorkspaceHttpAccessControl<AC>
impl<AC> MiddlewareAccessControl for WorkspaceMiddlewareAccessControl<AC>
where
AC: WorkspaceAccessControl,
{
@ -53,47 +110,52 @@ where
AccessResource::Workspace
}
#[instrument(level = "trace", skip_all, err)]
async fn check_workspace_permission(
#[instrument(name = "check_workspace_permission", level = "trace", skip_all)]
async fn check_resource_permission(
&self,
workspace_id: &Uuid,
uid: &i64,
method: Method,
) -> Result<(), AppError> {
trace!("workspace_id: {:?}, uid: {:?}", workspace_id, uid);
let role = self
.access_control
.get_workspace_role(uid, workspace_id, &self.pg_pool)
.await
.map_err(|err| {
AppError::NotEnoughPermissions(format!(
"Can't find the role of the user:{:?} in the workspace:{:?}. error: {}",
uid, workspace_id, err
))
})?;
match method {
Method::DELETE | Method::POST | Method::PUT => match role {
AFRole::Owner => return Ok(()),
_ => {
return Err(AppError::NotEnoughPermissions(format!(
"User:{:?} doesn't have the enough permission to access workspace:{}",
uid, workspace_id
)))
},
},
_ => Ok(()),
}
}
async fn check_collab_permission(
&self,
oid: &str,
uid: &i64,
resource_id: &str,
method: Method,
path: &Path<Url>,
) -> Result<(), AppError> {
error!("The check_collab_permission is not implemented");
Ok(())
if self.should_skip(&method, path) {
trace!("Skip access control for the request");
return Ok(());
}
// For some specific resources, we require a specific role to access them instead of the action.
// For example, both AFRole::Owner and AFRole::Member have write permission to the workspace,
// but only the Owner can manage the workspace members.
let require_role = self.require_role(&method, path);
let result = match require_role {
Some(role) => {
self
.access_control
.enforce_role(uid, resource_id, role)
.await
},
None => {
// If the request doesn't match any specific resources, we enforce the action.
let action = Action::from(&method);
self
.access_control
.enforce_action(uid, resource_id, action)
.await
},
}?;
if result {
Ok(())
} else {
Err(AppError::NotEnoughPermissions {
user: uid.to_string(),
action: format!(
"access workspace:{} with given url:{}, method: {}",
resource_id,
path.as_str(),
method,
),
})
}
}
}
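Editor's note: the comment inside check_resource_permission describes the decision order: skip rules first, then resource-specific role requirements, and only then an action derived from the HTTP method. The std-only sketch below illustrates that flow with stand-in enums (MethodStub, RoleStub, ActionStub), literal paths in place of WORKSPACE_PATTERN / WORKSPACE_MEMBER_PATTERN, and an assumed method-to-action mapping rather than the project's actual Action::from(&Method) impl.

#[derive(Clone, Copy, PartialEq)]
#[allow(dead_code)]
enum MethodStub { Get, Post, Put, Delete }

#[derive(Clone, Copy, PartialEq, Debug)]
#[allow(dead_code)]
enum RoleStub { Owner, Member, Guest }

#[derive(Clone, Copy, PartialEq, Debug)]
enum ActionStub { Read, Write, Delete }

enum Requirement { Skip, Role(RoleStub), Action(ActionStub) }

fn requirement(method: MethodStub, path: &str) -> Requirement {
    // 1. Skip rules: creating a workspace has nothing to enforce against yet.
    if method == MethodStub::Post && path == "/api/workspace" {
        return Requirement::Skip;
    }
    // 2. Resource-specific role rules: member management requires the Owner role.
    if path.ends_with("/member") {
        return Requirement::Role(RoleStub::Owner);
    }
    // 3. Otherwise derive an action from the HTTP method (assumed mapping).
    let action = match method {
        MethodStub::Get => ActionStub::Read,
        MethodStub::Post | MethodStub::Put => ActionStub::Write,
        MethodStub::Delete => ActionStub::Delete,
    };
    Requirement::Action(action)
}

fn main() {
    assert!(matches!(requirement(MethodStub::Post, "/api/workspace"), Requirement::Skip));
    assert!(matches!(
        requirement(MethodStub::Get, "/api/workspace/w1/member"),
        Requirement::Role(RoleStub::Owner)
    ));
    assert!(matches!(
        requirement(MethodStub::Put, "/api/workspace/w1"),
        Requirement::Action(ActionStub::Write)
    ));
}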

View file

@ -9,11 +9,11 @@ use database::resource_usage::get_all_workspace_blob_metadata;
use database::user::select_uid_from_email;
use database::workspace::{
change_workspace_icon, delete_from_workspace, delete_workspace_members, insert_user_workspace,
insert_workspace_member_with_txn, rename_workspace, select_all_user_workspaces, select_workspace,
select_workspace_member_list, update_updated_at_of_workspace, upsert_workspace_member,
rename_workspace, select_all_user_workspaces, select_workspace, select_workspace_member_list,
update_updated_at_of_workspace, upsert_workspace_member, upsert_workspace_member_with_txn,
};
use database_entity::dto::{AFAccessLevel, AFRole, AFWorkspace};
use realtime::collaborate::CollabAccessControl;
use shared_entity::dto::workspace_dto::{CreateWorkspaceMember, WorkspaceMemberChangeset};
use shared_entity::response::AppResponseError;
use sqlx::{types::uuid, PgPool};
@ -24,7 +24,7 @@ use tracing::instrument;
use uuid::Uuid;
use workspace_template::document::get_started::GetStartedDocumentTemplate;
use crate::biz::collab::storage::CollabStorageImpl;
use crate::biz::collab::storage::CollabAccessControlStorage;
use crate::biz::user::initialize_workspace_for_user;
pub async fn delete_workspace_for_user(
@ -57,8 +57,7 @@ pub async fn delete_workspace_for_user(
pub async fn create_workspace_for_user(
pg_pool: &PgPool,
workspace_access_control: &impl WorkspaceAccessControl,
collab_access_control: &impl CollabAccessControl,
collab_storage: &Arc<CollabStorageImpl>,
collab_storage: &Arc<CollabAccessControlStorage>,
user_uuid: &Uuid,
user_uid: i64,
workspace_name: &str,
@ -68,16 +67,9 @@ pub async fn create_workspace_for_user(
let new_workspace = AFWorkspace::try_from(new_workspace_row)?;
workspace_access_control
.insert_workspace_role(&user_uid, &new_workspace.workspace_id, AFRole::Owner)
.insert_role(&user_uid, &new_workspace.workspace_id, AFRole::Owner)
.await?;
collab_access_control
.insert_collab_access_level(
&user_uid,
&new_workspace.workspace_id.to_string(),
AFAccessLevel::FullAccess,
)
.await?;
// add create initial collab for user
initialize_workspace_for_user(
user_uid,
@ -170,18 +162,9 @@ pub async fn add_workspace_members(
let mut role_by_uid = HashMap::new();
for member in members.into_iter() {
let access_level = match &member.role {
AFRole::Owner => AFAccessLevel::FullAccess,
AFRole::Member => AFAccessLevel::ReadAndWrite,
AFRole::Guest => AFAccessLevel::ReadOnly,
};
let access_level = AFAccessLevel::from(&member.role);
let uid = select_uid_from_email(txn.deref_mut(), &member.email).await?;
// .context(format!(
// "Failed to get uid from email {} when adding workspace members",
// member.email
// ))?;
insert_workspace_member_with_txn(&mut txn, workspace_id, &member.email, member.role.clone())
upsert_workspace_member_with_txn(&mut txn, workspace_id, &member.email, member.role.clone())
.await?;
upsert_collab_member_with_txn(uid, workspace_id.to_string(), &access_level, &mut txn).await?;
role_by_uid.insert(uid, member.role);
@ -189,7 +172,7 @@ pub async fn add_workspace_members(
for (uid, role) in role_by_uid {
workspace_access_control
.insert_workspace_role(&uid, workspace_id, role)
.insert_role(&uid, workspace_id, role)
.await?;
}
txn
@ -200,10 +183,10 @@ pub async fn add_workspace_members(
}
pub async fn remove_workspace_members(
user_uuid: &Uuid,
pg_pool: &PgPool,
workspace_id: &Uuid,
member_emails: &[String],
workspace_access_control: &impl WorkspaceAccessControl,
) -> Result<(), AppResponseError> {
let mut txn = pg_pool
.begin()
@ -211,7 +194,15 @@ pub async fn remove_workspace_members(
.context("Begin transaction to delete workspace members")?;
for email in member_emails {
delete_workspace_members(user_uuid, &mut txn, workspace_id, email.as_str()).await?;
delete_workspace_members(&mut txn, workspace_id, email.as_str()).await?;
if let Ok(uid) = select_uid_from_email(txn.deref_mut(), email)
.await
.map_err(AppResponseError::from)
{
workspace_access_control
.remove_role(&uid, workspace_id)
.await?;
}
}
txn
@ -230,16 +221,18 @@ pub async fn get_workspace_members(
}
pub async fn update_workspace_member(
uid: &i64,
pg_pool: &PgPool,
workspace_id: &Uuid,
changeset: &WorkspaceMemberChangeset,
workspace_access_control: &impl WorkspaceAccessControl,
) -> Result<(), AppError> {
upsert_workspace_member(
pg_pool,
workspace_id,
&changeset.email,
changeset.role.clone(),
)
.await?;
if let Some(role) = &changeset.role {
upsert_workspace_member(pg_pool, workspace_id, &changeset.email, role.clone()).await?;
workspace_access_control
.insert_role(uid, workspace_id, role.clone())
.await?;
}
Ok(())
}
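Editor's note: add_workspace_members now derives the collab access level with AFAccessLevel::from(&member.role) instead of the inline match the diff removes. The mapping itself is unchanged (Owner → FullAccess, Member → ReadAndWrite, Guest → ReadOnly); here it is spelled out with local stand-in enums, since the real AFRole and AFAccessLevel live in database_entity::dto.

#[derive(Clone, Copy, PartialEq, Debug)]
enum RoleStub { Owner, Member, Guest }

#[derive(Clone, Copy, PartialEq, Debug)]
enum AccessLevelStub { FullAccess, ReadAndWrite, ReadOnly }

impl From<RoleStub> for AccessLevelStub {
    fn from(role: RoleStub) -> Self {
        match role {
            // Same mapping as the removed match in add_workspace_members.
            RoleStub::Owner => AccessLevelStub::FullAccess,
            RoleStub::Member => AccessLevelStub::ReadAndWrite,
            RoleStub::Guest => AccessLevelStub::ReadOnly,
        }
    }
}

fn main() {
    assert_eq!(AccessLevelStub::from(RoleStub::Owner), AccessLevelStub::FullAccess);
    assert_eq!(AccessLevelStub::from(RoleStub::Member), AccessLevelStub::ReadAndWrite);
    assert_eq!(AccessLevelStub::from(RoleStub::Guest), AccessLevelStub::ReadOnly);
}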

View file

@ -14,7 +14,7 @@ use dashmap::DashMap;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::future::{ready, Ready};
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tracing::error;
@ -37,103 +37,42 @@ pub enum AccessResource {
/// The collab and workspace access control can be separated into different traits. Currently, they are
/// combined into one trait.
#[async_trait]
pub trait HttpAccessControlService: Send + Sync {
pub trait MiddlewareAccessControl: Send + Sync {
fn resource(&self) -> AccessResource;
#[allow(unused_variables)]
async fn check_workspace_permission(
async fn check_resource_permission(
&self,
workspace_id: &Uuid,
uid: &i64,
method: Method,
) -> Result<(), AppError>;
#[allow(unused_variables)]
async fn check_collab_permission(
&self,
oid: &str,
uid: &i64,
resource_id: &str,
method: Method,
path: &Path<Url>,
) -> Result<(), AppError>;
}
#[async_trait]
impl<T> HttpAccessControlService for Arc<T>
where
T: HttpAccessControlService,
{
fn resource(&self) -> AccessResource {
self.as_ref().resource()
}
async fn check_workspace_permission(
&self,
workspace_id: &Uuid,
uid: &i64,
method: Method,
) -> Result<(), AppError> {
self
.as_ref()
.check_workspace_permission(workspace_id, uid, method)
.await
}
async fn check_collab_permission(
&self,
oid: &str,
uid: &i64,
method: Method,
path: &Path<Url>,
) -> Result<(), AppError> {
self
.as_ref()
.check_collab_permission(oid, uid, method, path)
.await
}
}
pub type HttpAccessControlServices =
Arc<HashMap<AccessResource, Arc<dyn HttpAccessControlService>>>;
/// Implement the access control for the workspace and collab.
/// It will check the permission of the request if the request is related to workspace or collab.
#[derive(Clone, Default)]
pub struct WorkspaceAccessControl {
access_control_services: HttpAccessControlServices,
pub struct MiddlewareAccessControlTransform {
controllers: Arc<HashMap<AccessResource, Arc<dyn MiddlewareAccessControl>>>,
}
impl WorkspaceAccessControl {
impl MiddlewareAccessControlTransform {
pub fn new() -> Self {
Self::default()
}
pub fn with_acs<T: HttpAccessControlService + 'static>(
pub fn with_acs<T: MiddlewareAccessControl + 'static>(
mut self,
access_control_service: T,
) -> Self {
let resource = access_control_service.resource();
Arc::make_mut(&mut self.access_control_services)
.insert(resource, Arc::new(access_control_service));
Arc::make_mut(&mut self.controllers).insert(resource, Arc::new(access_control_service));
self
}
}
impl Deref for WorkspaceAccessControl {
type Target = HttpAccessControlServices;
fn deref(&self) -> &Self::Target {
&self.access_control_services
}
}
impl DerefMut for WorkspaceAccessControl {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.access_control_services
}
}
impl<S, B> Transform<S, ServiceRequest> for WorkspaceAccessControl
impl<S, B> Transform<S, ServiceRequest> for MiddlewareAccessControlTransform
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
@ -141,14 +80,14 @@ where
{
type Response = ServiceResponse<B>;
type Error = Error;
type Transform = WorkspaceAccessControlMiddleware<S>;
type Transform = AccessControlMiddleware<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(WorkspaceAccessControlMiddleware {
ready(Ok(AccessControlMiddleware {
service,
access_control_service: self.access_control_services.clone(),
controllers: self.controllers.clone(),
}))
}
}
@ -158,15 +97,15 @@ where
/// are used to identify the workspace and collab.
///
/// For example, if the request path is `/api/workspace/{workspace_id}/collab/{object_id}`, then the
/// [WorkspaceAccessControlMiddleware] will check the permission of the workspace and collab.
/// [AccessControlMiddleware] will check the permission of the workspace and collab.
///
///
pub struct WorkspaceAccessControlMiddleware<S> {
pub struct AccessControlMiddleware<S> {
service: S,
access_control_service: HttpAccessControlServices,
controllers: Arc<HashMap<AccessResource, Arc<dyn MiddlewareAccessControl>>>,
}
impl<S, B> Service<ServiceRequest> for WorkspaceAccessControlMiddleware<S>
impl<S, B> Service<ServiceRequest> for AccessControlMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
@ -199,7 +138,7 @@ where
let user_uuid = req.extract::<UserUuid>();
let user_cache = req
.app_data::<Data<AppState>>()
.map(|state| state.users.clone());
.map(|state| state.user_cache.clone());
let uid = async {
let user_uuid = user_uuid.await.map_err(|err| {
@ -220,50 +159,47 @@ where
let workspace_id = path
.get(WORKSPACE_ID_PATH)
.and_then(|id| Uuid::parse_str(id).ok());
let collab_object_id = path.get(COLLAB_OBJECT_ID_PATH).map(|id| id.to_string());
let object_id = path.get(COLLAB_OBJECT_ID_PATH).map(|id| id.to_string());
let method = req.method().clone();
let fut = self.service.call(req);
let services = self.access_control_service.clone();
let services = self.controllers.clone();
Box::pin(async move {
// If the workspace_id or collab_object_id is not present, skip the access control
if workspace_id.is_some() || collab_object_id.is_some() {
let uid = uid.await?;
if workspace_id.is_none() && object_id.is_none() {
return fut.await;
}
// check workspace permission
if let Some(workspace_id) = workspace_id {
if let Some(acs) = services.get(&AccessResource::Workspace) {
if let Err(err) = acs
.check_workspace_permission(&workspace_id, &uid, method.clone())
.await
{
error!(
"workspace access control: {}, with path:{}",
err,
path.as_str()
);
return Err(Error::from(err));
}
};
}
let uid = uid.await?;
// check workspace permission
if let Some(workspace_id) = workspace_id {
if let Some(workspace_ac) = services.get(&AccessResource::Workspace) {
if let Err(err) = workspace_ac
.check_resource_permission(&uid, &workspace_id.to_string(), method.clone(), &path)
.await
{
error!("workspace access control: {}", err,);
return Err(Error::from(err));
}
};
}
// check collab permission
if let Some(collab_object_id) = collab_object_id {
if let Some(acs) = services.get(&AccessResource::Collab) {
if let Err(err) = acs
.check_collab_permission(&collab_object_id, &uid, method, &path)
.await
{
error!(
"collab access control: {:?}, with path:{}",
err,
path.as_str()
);
return Err(Error::from(err));
}
};
}
// check collab permission
if let Some(collab_object_id) = object_id {
if let Some(collab_ac) = services.get(&AccessResource::Collab) {
if let Err(err) = collab_ac
.check_resource_permission(&uid, &collab_object_id, method, &path)
.await
{
error!(
"collab access control: {:?}, with path:{}",
err,
path.as_str()
);
return Err(Error::from(err));
}
};
}
// call next service
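Editor's note: as the doc comment above explains, the middleware keys its checks off the workspace_id and object_id segments of paths like /api/workspace/{workspace_id}/collab/{object_id}, and returns early when neither is present. The dependency-free sketch below shows that extraction and early-exit rule; the real code reads the ids from the matched actix route parameters (WORKSPACE_ID_PATH / COLLAB_OBJECT_ID_PATH), not by splitting strings.

// Derive the two resource ids from a path shaped like
// /api/workspace/{workspace_id}/collab/{object_id}.
fn resource_ids(path: &str) -> (Option<String>, Option<String>) {
    let segments: Vec<&str> = path.trim_matches('/').split('/').collect();
    let after = |name: &str| {
        segments
            .iter()
            .position(|s| *s == name)
            .and_then(|i| segments.get(i + 1))
            .map(|s| s.to_string())
    };
    (after("workspace"), after("collab"))
}

fn main() {
    let (workspace_id, object_id) = resource_ids("/api/workspace/w1/collab/o1");
    assert_eq!(workspace_id.as_deref(), Some("w1"));
    assert_eq!(object_id.as_deref(), Some("o1"));
    // When neither id is present, the middleware skips access control entirely.
    assert_eq!(resource_ids("/api/user/profile"), (None, None));
}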

View file

@ -1,5 +1,5 @@
use crate::biz::casbin::{CollabAccessControlImpl, WorkspaceAccessControlImpl};
use crate::biz::collab::storage::CollabStorageImpl;
use crate::biz::collab::storage::CollabAccessControlStorage;
use crate::biz::pg_listener::PgListeners;
use crate::config::config::Config;
@ -15,6 +15,7 @@ use realtime::collaborate::RealtimeMetrics;
use snowflake::Snowflake;
use sqlx::PgPool;
use crate::biz::collab::cache::CollabCache;
use crate::biz::collab::metrics::CollabMetrics;
use std::sync::Arc;
use tokio::sync::RwLock;
@ -27,11 +28,12 @@ pub type RedisClient = redis::aio::ConnectionManager;
pub struct AppState {
pub pg_pool: PgPool,
pub config: Arc<Config>,
pub users: Arc<UserCache>,
pub user_cache: UserCache,
pub id_gen: Arc<RwLock<Snowflake>>,
pub gotrue_client: gotrue::api::Client,
pub redis_client: RedisClient,
pub collab_storage: Arc<CollabStorageImpl>,
pub collab_cache: CollabCache,
pub collab_access_control_storage: Arc<CollabAccessControlStorage>,
pub collab_access_control: CollabAccessControlImpl,
pub workspace_access_control: WorkspaceAccessControlImpl,
pub bucket_storage: Arc<S3BucketStorage>,
@ -56,9 +58,10 @@ pub struct AuthenticateUser {
pub const EXPIRED_DURATION_DAYS: i64 = 30;
#[derive(Clone)]
pub struct UserCache {
pool: PgPool,
users: DashMap<Uuid, AuthenticateUser>,
users: Arc<DashMap<Uuid, AuthenticateUser>>,
}
impl UserCache {
@ -78,7 +81,10 @@ impl UserCache {
users
};
Self { pool, users }
Self {
pool,
users: Arc::new(users),
}
}
/// Get the user's uid from the cache or the database.

View file

@ -1,386 +0,0 @@
use crate::access_control::*;
use actix_http::Method;
use anyhow::{anyhow, Context};
use appflowy_cloud::biz;
use appflowy_cloud::biz::casbin::access_control::{Action, ActionType, ObjectType};
use database_entity::dto::{AFAccessLevel, AFRole};
use realtime::collaborate::CollabAccessControl;
use serial_test::serial;
use shared_entity::dto::workspace_dto::CreateWorkspaceMember;
use sqlx::PgPool;
use std::time::Duration;
use tokio::time::sleep;
#[sqlx::test(migrations = false)]
#[serial]
async fn test_collab_access_control(pool: PgPool) -> anyhow::Result<()> {
let access_control = setup_access_control(&pool).await?;
let collab_access_control = access_control.new_collab_access_control();
let workspace_access_control = access_control.new_workspace_access_control();
let user = create_user(&pool).await?;
let owner = create_user(&pool).await?;
let member = create_user(&pool).await?;
let guest = create_user(&pool).await?;
// Get workspace details
let workspace = database::workspace::select_user_workspace(&pool, &user.uuid)
.await?
.into_iter()
.next()
.ok_or(anyhow!("workspace should be created"))?;
let members = vec![
CreateWorkspaceMember {
email: owner.email.clone(),
role: AFRole::Owner,
},
CreateWorkspaceMember {
email: member.email.clone(),
role: AFRole::Member,
},
CreateWorkspaceMember {
email: guest.email.clone(),
role: AFRole::Guest,
},
];
biz::workspace::ops::add_workspace_members(
&pool,
&user.uuid,
&workspace.workspace_id,
members,
&workspace_access_control,
)
.await
.context("adding users to workspace")?;
// user that created the workspace should have full access
assert_access_level(
&collab_access_control,
&user.uid,
workspace.workspace_id.to_string(),
Some(AFAccessLevel::FullAccess),
)
.await;
// member should have read and write access
assert_access_level(
&collab_access_control,
&member.uid,
workspace.workspace_id.to_string(),
Some(AFAccessLevel::ReadAndWrite),
)
.await;
// guest should have read access
assert_access_level(
&collab_access_control,
&guest.uid,
workspace.workspace_id.to_string(),
Some(AFAccessLevel::ReadOnly),
)
.await;
let mut txn = pool
.begin()
.await
.context("acquire transaction to update collab member")?;
// update guest access level to read and comment
database::collab::upsert_collab_member_with_txn(
guest.uid,
&workspace.workspace_id.to_string(),
&AFAccessLevel::ReadAndComment,
&mut txn,
)
.await?;
txn
.commit()
.await
.expect("commit transaction to update collab member");
// guest should have read and comment access
assert_access_level(
&collab_access_control,
&guest.uid,
workspace.workspace_id.to_string(),
Some(AFAccessLevel::ReadAndComment),
)
.await;
database::collab::delete_collab_member(guest.uid, &workspace.workspace_id.to_string(), &pool)
.await
.context("delete collab member")?;
// guest should not have access after removed from collab
assert_access_level(
&collab_access_control,
&guest.uid,
workspace.workspace_id.to_string(),
None,
)
.await;
Ok(())
}
#[sqlx::test(migrations = false)]
#[serial]
async fn test_collab_access_control_when_obj_not_exist(pool: PgPool) -> anyhow::Result<()> {
let access_control = setup_access_control(&pool).await?;
let collab_access_control = access_control.new_collab_access_control();
let user = create_user(&pool).await?;
for method in [Method::GET, Method::POST, Method::PUT, Method::DELETE] {
assert_can_access_http_method(&collab_access_control, &user.uid, "fake_id", method, true)
.await
.unwrap();
}
Ok(())
}
#[sqlx::test(migrations = false)]
#[serial]
async fn test_collab_access_control_access_http_method(pool: PgPool) -> anyhow::Result<()> {
let access_control = setup_access_control(&pool).await?;
let collab_access_control = access_control.new_collab_access_control();
let workspace_access_control = access_control.new_workspace_access_control();
let user = create_user(&pool).await?;
let guest = create_user(&pool).await?;
let stranger = create_user(&pool).await?;
// Get workspace details
let workspace = database::workspace::select_user_workspace(&pool, &user.uuid)
.await?
.into_iter()
.next()
.ok_or(anyhow!("workspace should be created"))?;
biz::workspace::ops::add_workspace_members(
&pool,
&guest.uuid,
&workspace.workspace_id,
vec![CreateWorkspaceMember {
email: guest.email,
role: AFRole::Guest,
}],
&workspace_access_control,
)
.await
.context("adding users to workspace")
.unwrap();
for method in [Method::GET, Method::POST, Method::PUT, Method::DELETE] {
assert_can_access_http_method(
&collab_access_control,
&user.uid,
&workspace.workspace_id.to_string(),
method,
true,
)
.await
.unwrap();
}
assert!(
collab_access_control
.can_access_http_method(&user.uid, "new collab oid", &Method::POST)
.await?,
"should have access to non-existent collab oid"
);
// guest should have read access
assert_can_access_http_method(
&collab_access_control,
&guest.uid,
&workspace.workspace_id.to_string(),
Method::GET,
true,
)
.await
.unwrap();
// guest should not have write access
assert_can_access_http_method(
&collab_access_control,
&guest.uid,
&workspace.workspace_id.to_string(),
Method::POST,
false,
)
.await
.unwrap();
assert!(
!collab_access_control
.can_access_http_method(
&stranger.uid,
&workspace.workspace_id.to_string(),
&Method::GET
)
.await?,
"stranger should not have read access"
);
//
assert!(
!collab_access_control
.can_access_http_method(
&stranger.uid,
&workspace.workspace_id.to_string(),
&Method::POST
)
.await?,
"stranger should not have write access"
);
Ok(())
}
#[sqlx::test(migrations = false)]
#[serial]
async fn test_collab_access_control_send_receive_collab_update(pool: PgPool) -> anyhow::Result<()> {
let access_control = setup_access_control(&pool).await?;
let collab_access_control = access_control.new_collab_access_control();
let workspace_access_control = access_control.new_workspace_access_control();
let user = create_user(&pool).await?;
let guest = create_user(&pool).await?;
let stranger = create_user(&pool).await?;
// Get workspace details
let workspace = database::workspace::select_user_workspace(&pool, &user.uuid)
.await?
.into_iter()
.next()
.ok_or(anyhow!("workspace should be created"))?;
biz::workspace::ops::add_workspace_members(
&pool,
&guest.uuid,
&workspace.workspace_id,
vec![CreateWorkspaceMember {
email: guest.email,
role: AFRole::Guest,
}],
&workspace_access_control,
)
.await
.context("adding users to workspace")?;
// Need to wait for the listener(spawn_listen_on_workspace_member_change) to receive the event
sleep(Duration::from_secs(2)).await;
assert!(
collab_access_control
.can_send_collab_update(&user.uid, &workspace.workspace_id.to_string())
.await?
);
assert!(
collab_access_control
.can_receive_collab_update(&user.uid, &workspace.workspace_id.to_string())
.await?
);
assert!(
!collab_access_control
.can_send_collab_update(&guest.uid, &workspace.workspace_id.to_string())
.await?,
"guest cannot send collab update"
);
assert!(
collab_access_control
.can_receive_collab_update(&guest.uid, &workspace.workspace_id.to_string())
.await?,
"guest can receive collab update"
);
assert!(
!collab_access_control
.can_send_collab_update(&stranger.uid, &workspace.workspace_id.to_string())
.await?,
"stranger cannot send collab update"
);
assert!(
!collab_access_control
.can_receive_collab_update(&stranger.uid, &workspace.workspace_id.to_string())
.await?,
"stranger cannot receive collab update"
);
Ok(())
}
#[sqlx::test(migrations = false)]
#[serial]
async fn test_collab_access_control_cache_collab_access_level(pool: PgPool) -> anyhow::Result<()> {
let access_control = setup_access_control(&pool).await?;
let collab_access_control = access_control.new_collab_access_control();
let uid = 123;
let oid = "collab::oid".to_owned();
collab_access_control
.insert_collab_access_level(&uid, &oid, AFAccessLevel::FullAccess)
.await?;
assert_eq!(
AFAccessLevel::FullAccess,
collab_access_control
.get_collab_access_level(&uid, &oid)
.await?
);
collab_access_control
.insert_collab_access_level(&uid, &oid, AFAccessLevel::ReadOnly)
.await?;
assert_eq!(
AFAccessLevel::ReadOnly,
collab_access_control
.get_collab_access_level(&uid, &oid)
.await?
);
Ok(())
}
#[sqlx::test(migrations = false)]
#[serial]
async fn test_casbin_access_control_update_remove(pool: PgPool) -> anyhow::Result<()> {
let access_control = setup_access_control(&pool).await?;
let uid = 123;
assert!(
access_control
.update(
&uid,
&ObjectType::Workspace("123"),
&ActionType::Role(AFRole::Owner)
)
.await?
);
assert!(
access_control
.enforce(&uid, &ObjectType::Workspace("123"), Action::Write)
.await?
);
assert!(access_control
.remove(&uid, &ObjectType::Workspace("123"))
.await
.is_ok());
assert!(
access_control
.enforce(&uid, &ObjectType::Workspace("123"), Action::Read)
.await?
);
Ok(())
}

View file

@ -1,101 +0,0 @@
use crate::access_control::{
assert_workspace_role, assert_workspace_role_error, create_user, setup_access_control,
};
use anyhow::{anyhow, Context};
use app_error::ErrorCode;
use appflowy_cloud::biz;
use serial_test::serial;
use database_entity::dto::AFRole;
use shared_entity::dto::workspace_dto::{CreateWorkspaceMember, WorkspaceMemberChangeset};
use sqlx::PgPool;
#[sqlx::test(migrations = false)]
#[serial]
async fn test_workspace_access_control_get_role(pool: PgPool) -> anyhow::Result<()> {
let access_control = setup_access_control(&pool).await?;
let workspace_access_control = access_control.new_workspace_access_control();
let user = create_user(&pool).await?;
// Get workspace details
let workspace = database::workspace::select_user_workspace(&pool, &user.uuid)
.await?
.into_iter()
.next()
.ok_or(anyhow!("workspace should be created"))?;
assert_workspace_role(
&workspace_access_control,
&user.uid,
&workspace.workspace_id,
Some(AFRole::Owner),
&pool,
)
.await;
let member = create_user(&pool).await?;
biz::workspace::ops::add_workspace_members(
&pool,
&member.uuid,
&workspace.workspace_id,
vec![CreateWorkspaceMember {
email: member.email.clone(),
role: AFRole::Member,
}],
&workspace_access_control,
)
.await
.context("adding users to workspace")?;
assert_workspace_role(
&workspace_access_control,
&member.uid,
&workspace.workspace_id,
Some(AFRole::Member),
&pool,
)
.await;
// wait for update message
biz::workspace::ops::update_workspace_member(
&pool,
&workspace.workspace_id,
&WorkspaceMemberChangeset {
email: member.email.clone(),
role: Some(AFRole::Guest),
name: None,
},
)
.await
.context("update user workspace role")?;
assert_workspace_role(
&workspace_access_control,
&member.uid,
&workspace.workspace_id,
Some(AFRole::Guest),
&pool,
)
.await;
biz::workspace::ops::remove_workspace_members(
&user.uuid,
&pool,
&workspace.workspace_id,
&[member.email.clone()],
)
.await
.context("removing users from workspace")?;
assert_workspace_role_error(
&workspace_access_control,
&member.uid,
&workspace.workspace_id,
ErrorCode::RecordNotFound,
&pool,
)
.await;
Ok(())
}

View file

@ -1,337 +0,0 @@
use actix_http::Method;
use anyhow::{Context, Error};
use app_error::{AppError, ErrorCode};
use appflowy_cloud::biz::casbin::{
AFEnforcerCache, ActionCacheKey, CollabAccessControlImpl, PolicyCacheKey,
WorkspaceAccessControlImpl,
};
use appflowy_cloud::biz::workspace::access_control::WorkspaceAccessControl;
use client_api_test_util::setup_log;
use database_entity::dto::{AFAccessLevel, AFRole};
use lazy_static::lazy_static;
use realtime::collaborate::CollabAccessControl;
use snowflake::Snowflake;
use sqlx::PgPool;
use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::{interval, timeout};
use appflowy_cloud::biz::casbin::access_control::AccessControl;
use appflowy_cloud::biz::pg_listener::PgListeners;
use appflowy_cloud::state::AppMetrics;
use uuid::Uuid;
mod collab_ac_test;
mod member_ac_test;
mod user_ac_test;
lazy_static! {
pub static ref ID_GEN: RwLock<Snowflake> = RwLock::new(Snowflake::new(1));
}
pub async fn setup_access_control(pool: &PgPool) -> anyhow::Result<AccessControl> {
setup_db(pool).await?;
let metrics = AppMetrics::new();
let listeners = PgListeners::new(pool).await?;
let enforcer_cache = Arc::new(TestEnforcerCacheImpl {
cache: DashMap::new(),
});
Ok(
AccessControl::new(
pool.clone(),
listeners.subscribe_collab_member_change(),
listeners.subscribe_workspace_member_change(),
metrics.access_control_metrics,
enforcer_cache,
)
.await
.unwrap(),
)
}
pub async fn setup_db(pool: &PgPool) -> anyhow::Result<()> {
setup_log();
// Have to manually manage schema and tables managed by gotrue but referenced by our
// migration scripts.
// Create schema and tables
sqlx::query(r#"create schema auth"#).execute(pool).await?;
sqlx::query(
r#"create table auth.users(
id uuid NOT NULL UNIQUE,
deleted_at timestamptz null,
CONSTRAINT users_pkey PRIMARY KEY (id)
)"#,
)
.execute(pool)
.await?;
// Manually run migration after creating required objects above.
sqlx::migrate!().run(pool).await?;
// Remove foreign key constraint
sqlx::query(r#"alter table public.af_user drop constraint af_user_email_foreign_key"#)
.execute(pool)
.await?;
Ok(())
}
#[derive(Debug, Clone)]
pub struct User {
pub uid: i64,
pub uuid: Uuid,
pub email: String,
}
pub async fn create_user(pool: &PgPool) -> anyhow::Result<User> {
// Create user and workspace
let uid = ID_GEN.write().await.next_id();
let uuid = Uuid::new_v4();
let email = format!("{}@appflowy.io", uuid);
let name = uuid.to_string();
database::user::create_user(pool, uid, &uuid, &email, &name)
.await
.context("create user")?;
Ok(User { uid, uuid, email })
}
/// Asserts that the user has the specified access level within a workspace.
///
/// This function continuously checks the user's access level in a workspace and asserts that it
/// matches the expected level. The function retries the check a fixed number of times before timing out.
///
/// # Panics
/// Panics if the expected access level is not achieved before the timeout.
pub async fn assert_access_level<T: AsRef<str>>(
access_control: &CollabAccessControlImpl,
uid: &i64,
workspace_id: T,
expected_level: Option<AFAccessLevel>,
) {
let mut retry_count = 0;
loop {
tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(10)) => {
panic!("can't get the expected access level before timeout");
},
result = access_control
.get_collab_access_level(
uid,
workspace_id.as_ref(),
)
=> {
retry_count += 1;
match result {
Ok(access_level) => {
if retry_count > 10 {
assert_eq!(access_level, expected_level.unwrap());
break;
}
if let Some(expected_level) = expected_level {
if access_level == expected_level {
break;
}
}
tokio::time::sleep(Duration::from_millis(300)).await;
},
Err(err) => {
if err.is_record_not_found() & expected_level.is_none() {
break;
}
tokio::time::sleep(Duration::from_millis(1000)).await;
}
}
},
}
}
}
/// Asserts that the user has the specified role within a workspace.
///
/// This function continuously checks the user's role in a workspace and asserts that it matches
/// the expected role. It retries the check a fixed number of times before timing out.
///
/// # Panics
/// Panics if the expected role is not achieved before the timeout.
pub async fn assert_workspace_role(
access_control: &WorkspaceAccessControlImpl,
uid: &i64,
workspace_id: &Uuid,
expected_role: Option<AFRole>,
pg_pool: &PgPool,
) {
let mut retry_count = 0;
let timeout = Duration::from_secs(10);
let start_time = tokio::time::Instant::now();
loop {
if retry_count > 10 {
// This check should be outside of the select! block to prevent panic before checking the condition.
panic!("Exceeded maximum number of retries");
}
if start_time.elapsed() > timeout {
panic!("can't get the expected role before timeout");
}
match access_control
.get_workspace_role(uid, workspace_id, pg_pool)
.await
{
Ok(role) if Some(&role) == expected_role.as_ref() => {
// If the roles match, or if the expected role is None and no role is found, break the loop
break;
},
Err(err) if err.is_record_not_found() && expected_role.is_none() => {
// If no record is found and no role is expected, break the loop
break;
},
Err(err) if err.is_record_not_found() => {
// If no record is found but a role is expected, wait and retry
tokio::time::sleep(Duration::from_millis(1000)).await;
},
_ => {
// If the roles do not match, or any other error occurs, wait and retry
tokio::time::sleep(Duration::from_millis(300)).await;
},
}
retry_count += 1;
}
}
/// Asserts that retrieving the user's role within a workspace results in a specific error.
///
/// This function continuously attempts to fetch the user's role in a workspace, expecting a specific
/// error. If the expected error does not occur within a certain number of retries, it panics.
///
/// # Panics
/// Panics if the expected error is not encountered before the timeout or if an unexpected role is received.
pub async fn assert_workspace_role_error(
access_control: &WorkspaceAccessControlImpl,
uid: &i64,
workspace_id: &Uuid,
expected_error: ErrorCode,
pg_pool: &PgPool,
) {
let timeout_duration = Duration::from_secs(10);
let retry_interval = Duration::from_millis(300);
let mut retries = 0usize;
let max_retries = 10;
let operation = async {
let mut interval = interval(retry_interval);
loop {
interval.tick().await; // Wait for the next interval tick before retrying
match access_control
.get_workspace_role(uid, workspace_id, pg_pool)
.await
{
Ok(_) => {},
Err(err) if err.code() == expected_error => {
// If the error matches the expected error, exit successfully
return;
},
Err(_) => {
retries += 1;
if retries > max_retries {
// If retries exceed the maximum, return an error
panic!("Exceeded maximum number of retries without encountering the expected error");
}
// On any other error, continue retrying
},
}
}
};
timeout(timeout_duration, operation)
.await
.expect("Operation timed out");
}
pub async fn assert_can_access_http_method(
access_control: &CollabAccessControlImpl,
uid: &i64,
object_id: &str,
method: Method,
expected: bool,
) -> Result<(), Error> {
let timeout_duration = Duration::from_secs(10);
let retry_interval = Duration::from_millis(1000);
let mut retries = 0usize;
let max_retries = 10;
let operation = async {
let mut interval = interval(retry_interval);
loop {
interval.tick().await; // Wait for the next interval tick before retrying
match access_control
.can_access_http_method(uid, object_id, &method)
.await
{
Ok(access) => {
if access == expected {
break;
}
},
Err(_) => {
retries += 1;
if retries > max_retries {
// If retries exceed the maximum, return an error
panic!("Exceeded maximum number of retries without encountering the expected error");
}
// On any other error, continue retrying
},
}
}
};
timeout(timeout_duration, operation).await?;
Ok(())
}
struct TestEnforcerCacheImpl {
cache: DashMap<String, String>,
}
#[async_trait]
impl AFEnforcerCache for TestEnforcerCacheImpl {
async fn set_enforcer_result(&self, key: &PolicyCacheKey, value: bool) -> Result<(), AppError> {
self
.cache
.insert(key.as_ref().to_string(), value.to_string());
Ok(())
}
async fn get_enforcer_result(&self, key: &PolicyCacheKey) -> Option<bool> {
self
.cache
.get(key.as_ref())
.map(|v| v.value().parse().unwrap())
}
async fn remove_enforcer_result(&self, key: &PolicyCacheKey) {
self.cache.remove(key.as_ref());
}
async fn set_action(&self, key: &ActionCacheKey, value: String) -> Result<(), AppError> {
self.cache.insert(key.as_ref().to_string(), value);
Ok(())
}
async fn get_action(&self, key: &ActionCacheKey) -> Option<String> {
self.cache.get(key.as_ref()).map(|v| v.value().to_string())
}
async fn remove_action(&self, key: &ActionCacheKey) {
self.cache.remove(key.as_ref());
}
}

View file

@ -1,230 +0,0 @@
use crate::access_control::*;
use anyhow::anyhow;
use appflowy_cloud::biz;
use appflowy_cloud::biz::casbin::access_control::{Action, ObjectType};
use database_entity::dto::{AFAccessLevel, AFRole};
use serial_test::serial;
use shared_entity::dto::workspace_dto::CreateWorkspaceMember;
use sqlx::PgPool;
#[sqlx::test(migrations = false)]
#[serial]
async fn test_create_user(pool: PgPool) -> anyhow::Result<()> {
let access_control = setup_access_control(&pool).await?;
let user = create_user(&pool).await?;
// Get workspace details
let workspace = database::workspace::select_user_workspace(&pool, &user.uuid)
.await?
.into_iter()
.next()
.ok_or(anyhow!("workspace should be created"))?;
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Workspace(&workspace.workspace_id.to_string()),
AFRole::Owner
)
.await
.context("user should be owner of its workspace")?);
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
AFAccessLevel::FullAccess,
)
.await
.context("user should have full access of its collab")?);
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Read,
)
.await
.context("user should be able to read its collab")?);
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Write,
)
.await
.context("user should be able to write its collab")?);
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Delete,
)
.await
.context("user should be able to delete its collab")?);
Ok(())
}
#[sqlx::test(migrations = false)]
#[serial]
async fn test_add_users_to_workspace(pool: PgPool) -> anyhow::Result<()> {
let access_control = setup_access_control(&pool).await?;
let workspace_access_control = access_control.new_workspace_access_control();
let user_main = create_user(&pool).await?;
let user_owner = create_user(&pool).await?;
let user_member = create_user(&pool).await?;
let user_guest = create_user(&pool).await?;
// Get workspace details
let workspace = database::workspace::select_user_workspace(&pool, &user_main.uuid)
.await?
.into_iter()
.next()
.ok_or(anyhow!("workspace should be created"))?;
let members = vec![
CreateWorkspaceMember {
email: user_owner.email.clone(),
role: AFRole::Owner,
},
CreateWorkspaceMember {
email: user_member.email.clone(),
role: AFRole::Member,
},
CreateWorkspaceMember {
email: user_guest.email.clone(),
role: AFRole::Guest,
},
];
biz::workspace::ops::add_workspace_members(
&pool,
&user_main.uuid,
&workspace.workspace_id,
members,
&workspace_access_control,
)
.await
.context("adding users to workspace")?;
{
// Owner
let user = user_owner;
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
AFAccessLevel::FullAccess,
)
.await
.context("owner should have full access of its collab")?);
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Read,
)
.await
.context("user should be able to read its collab")?);
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Write,
)
.await
.context("user should be able to write its collab")?);
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Delete,
)
.await
.context("user should be able to delete its collab")?);
}
{
// Member
let user = user_member;
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
AFAccessLevel::ReadAndWrite,
)
.await
.context("member should have read write access of its collab")?);
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Read,
)
.await
.context("user should be able to read its collab")?);
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Write,
)
.await
.context("user should be able to write its collab")?);
assert!(!access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Delete,
)
.await
.context("user should not be able to delete its collab")?);
}
{
// Guest
let user = user_guest;
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
AFAccessLevel::ReadOnly,
)
.await
.context("guest should have read only access of its collab")?);
assert!(access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Read,
)
.await
.context("user should not be able to read its collab")?);
assert!(!access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Write,
)
.await
.context("user should not be able to write its collab")?);
assert!(!access_control
.enforce(
&user.uid,
&ObjectType::Collab(&workspace.workspace_id.to_string()),
Action::Delete,
)
.await
.context("user should not be able to delete its collab")?);
}
Ok(())
}

View file

@ -1,4 +1,3 @@
// mod access_control;
mod collab;
mod collab_snapshot;
mod gotrue;

View file

@ -1,23 +1,55 @@
use app_error::ErrorCode;
use client_api_test_util::TestClient;
use database_entity::dto::AFRole;
use database_entity::dto::{AFAccessLevel, AFRole, QueryCollabMembers};
use shared_entity::dto::workspace_dto::CreateWorkspaceMember;
#[tokio::test]
async fn add_workspace_members_not_enough_permission() {
async fn get_workspace_owner_after_sign_up_test() {
let c1 = TestClient::new_user_without_ws_conn().await;
let c2 = TestClient::new_user_without_ws_conn().await;
let c3 = TestClient::new_user_without_ws_conn().await;
let workspace_id = c1.workspace_id().await;
// after the user signs up, they should be the owner of the workspace
let members = c1
.api_client
.get_workspace_members(&workspace_id)
.await
.unwrap();
assert_eq!(members.len(), 1);
assert_eq!(members[0].email, c1.email().await);
// after the user signs up, they should have full access to the workspace
let collab_members = c1
.api_client
.get_collab_members(QueryCollabMembers {
workspace_id: workspace_id.clone(),
object_id: workspace_id.clone(),
})
.await
.unwrap()
.0;
assert_eq!(collab_members.len(), 1);
assert_eq!(
collab_members[0].permission.access_level,
AFAccessLevel::FullAccess
);
}
#[tokio::test]
async fn add_workspace_members_not_enough_permission() {
let owner = TestClient::new_user_without_ws_conn().await;
let member_1 = TestClient::new_user_without_ws_conn().await;
let member_2 = TestClient::new_user_without_ws_conn().await;
let workspace_id = owner.workspace_id().await;
// add member_1 to the owner's workspace
c1.add_workspace_member(&workspace_id, &c2, AFRole::Member)
owner
.add_workspace_member(&workspace_id, &member_1, AFRole::Member)
.await;
// member_1 tries to add member_2 to the owner's workspace, but permission is denied
let error = c2
.try_add_workspace_member(&workspace_id, &c3, AFRole::Member)
let error = member_1
.try_add_workspace_member(&workspace_id, &member_2, AFRole::Member)
.await
.unwrap_err();
assert_eq!(error.code, ErrorCode::NotEnoughPermissions);
@ -107,35 +139,86 @@ async fn update_workspace_member_role_from_guest_to_member() {
}
#[tokio::test]
async fn workspace_second_owner_add_member() {
let c1 = TestClient::new_user_without_ws_conn().await;
let c2 = TestClient::new_user_without_ws_conn().await;
let c3 = TestClient::new_user_without_ws_conn().await;
async fn workspace_add_member() {
let owner = TestClient::new_user_without_ws_conn().await;
let other_owner = TestClient::new_user_without_ws_conn().await;
let member = TestClient::new_user_without_ws_conn().await;
let guest = TestClient::new_user_without_ws_conn().await;
let workspace_id = c1.workspace_id().await;
let workspace_id = owner.workspace_id().await;
// add other_owner to the workspace as a second owner
c1.add_workspace_member(&workspace_id, &c2, AFRole::Owner)
owner
.add_workspace_member(&workspace_id, &other_owner, AFRole::Owner)
.await;
// the second owner adds a member and a guest to the workspace
c2.add_workspace_member(&workspace_id, &c3, AFRole::Member)
other_owner
.add_workspace_member(&workspace_id, &member, AFRole::Member)
.await;
other_owner
.add_workspace_member(&workspace_id, &guest, AFRole::Guest)
.await;
let members = c1
let members = owner
.api_client
.get_workspace_members(&workspace_id)
.await
.unwrap();
assert_eq!(members.len(), 3);
assert_eq!(members[0].email, c1.email().await);
assert_eq!(members.len(), 4);
assert_eq!(members[0].email, owner.email().await);
assert_eq!(members[0].role, AFRole::Owner);
assert_eq!(members[1].email, c2.email().await);
assert_eq!(members[1].email, other_owner.email().await);
assert_eq!(members[1].role, AFRole::Owner);
assert_eq!(members[2].email, c3.email().await);
assert_eq!(members[2].email, member.email().await);
assert_eq!(members[2].role, AFRole::Member);
assert_eq!(members[3].email, guest.email().await);
assert_eq!(members[3].role, AFRole::Guest);
// after adding the members to the workspace, we should be able to get the collab members
// of the workspace.
let collab_members = owner
.api_client
.get_collab_members(QueryCollabMembers {
workspace_id: workspace_id.clone(),
object_id: workspace_id.clone(),
})
.await
.unwrap()
.0;
assert_eq!(collab_members.len(), 4);
// owner
assert_eq!(collab_members[0].uid, owner.uid().await);
assert_eq!(
collab_members[0].permission.access_level,
AFAccessLevel::FullAccess
);
// other owner
assert_eq!(collab_members[1].uid, other_owner.uid().await);
assert_eq!(
collab_members[1].permission.access_level,
AFAccessLevel::FullAccess
);
// member
assert_eq!(collab_members[2].uid, member.uid().await);
assert_eq!(
collab_members[2].permission.access_level,
AFAccessLevel::ReadAndWrite
);
// guest
assert_eq!(collab_members[3].uid, guest.uid().await);
assert_eq!(
collab_members[3].permission.access_level,
AFAccessLevel::ReadOnly
);
}
#[tokio::test]