feat: sync document through http request (#1064)

* chore: query embedding

* chore: create embeddings

* chore: apply update to editing collab

* refactor: web-update

* chore: calculate the missing update when the sv is not none

* chore: add test

* chore: fix audit

* chore: commit sqlx

* chore: fix client api

* test: add

* chore: clippy

* chore: fix collab drop when save
This commit is contained in:
Nathan.fooo 2024-12-12 14:53:07 +08:00 committed by GitHub
parent b4a0669361
commit af38efe6d9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
78 changed files with 1781 additions and 1019 deletions

View file

@ -29,7 +29,7 @@ env:
POSTGRES_PASSWORD: password
DATABASE_URL: postgres://postgres:password@localhost:5432/postgres
SQLX_OFFLINE: true
RUST_TOOLCHAIN: "1.78"
RUST_TOOLCHAIN: "1.80"
jobs:
setup:

View file

@ -11,7 +11,7 @@ env:
SQLX_VERSION: 0.7.1
SQLX_FEATURES: "rustls,postgres"
SQLX_OFFLINE: true
RUST_TOOLCHAIN: "1.78"
RUST_TOOLCHAIN: "1.80"
jobs:
test:

View file

@ -20,7 +20,7 @@ on:
env:
NODE_VERSION: '20.12.0'
RUST_TOOLCHAIN: "1.78.0"
RUST_TOOLCHAIN: "1.80.0"
jobs:
publish:
runs-on: ubuntu-latest

View file

@ -0,0 +1,29 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT\n oid AS object_id,\n indexed_at\n FROM af_collab_embeddings\n WHERE oid = $1 AND partition_key = $2\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "object_id",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "indexed_at",
"type_info": "Timestamp"
}
],
"parameters": {
"Left": [
"Text",
"Int4"
]
},
"nullable": [
false,
false
]
},
"hash": "567706898cc802c6ec72a95084a69d93277fe34650b9e2d2f58854d0ab4b7d8e"
}

367
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -58,8 +58,8 @@ reqwest = { workspace = true, features = [
unicode-segmentation = "1.10"
lazy_static.workspace = true
fancy-regex = "0.11.0"
bytes.workspace = true
validator.workspace = true
bytes = "1.5.0"
rcgen = { version = "0.10.0", features = ["pem", "x509-parser"] }
mime = "0.3.17"
aws-sdk-s3 = { version = "1.63.0", features = [
@ -158,6 +158,7 @@ console-subscriber = { version = "0.4.1", optional = true }
base64.workspace = true
md5.workspace = true
nanoid = "0.4.0"
http = "0.2.12"
[dev-dependencies]
once_cell = "1.19.0"
@ -245,7 +246,7 @@ secrecy = { version = "0.8", features = ["serde"] }
serde_json = "1.0.111"
serde_repr = "0.1.18"
serde = { version = "1.0.195", features = ["derive"] }
bytes = "1.5.0"
bytes = "1.9.0"
workspace-template = { path = "libs/workspace-template" }
uuid = { version = "1.6.1", features = ["v4", "v5"] }
anyhow = "1.0.94"
@ -308,7 +309,14 @@ debug = true
[profile.ci]
inherits = "release"
opt-level = 2
lto = false # Disable Link-Time Optimization
lto = false
[profile.dev]
opt-level = 0
lto = false
codegen-units = 128
incremental = true
debug = true
[patch.crates-io]
# It's difficult to resolve different versions of the same crate used in the AppFlowy Frontend and the Client-API crate.

View file

@ -1,2 +1,2 @@
[advisories]
ignore = ["RUSTSEC-2024-0370", "RUSTSEC-2024-0384"]
ignore = ["RUSTSEC-2024-0384"]

View file

@ -11,7 +11,7 @@ crate-type = ["cdylib", "rlib"]
thiserror = "1.0.56"
serde_repr = "0.1.18"
serde.workspace = true
anyhow = "1.0.79"
anyhow.workspace = true
uuid = { workspace = true, features = ["v4"] }
sqlx = { workspace = true, default-features = false, features = [
"postgres",

View file

@ -177,6 +177,12 @@ pub enum AppError {
#[error("{0}")]
ServiceTemporaryUnavailable(String),
#[error("Decode update error: {0}")]
DecodeUpdateError(String),
#[error("Apply update error:{0}")]
ApplyUpdateError(String),
}
impl AppError {
@ -254,6 +260,8 @@ impl AppError {
ErrorCode::CustomNamespaceInvalidCharacter
},
AppError::ServiceTemporaryUnavailable(_) => ErrorCode::ServiceTemporaryUnavailable,
AppError::DecodeUpdateError(_) => ErrorCode::DecodeUpdateError,
AppError::ApplyUpdateError(_) => ErrorCode::ApplyUpdateError,
}
}
}
@ -394,6 +402,8 @@ pub enum ErrorCode {
PublishNameTooLong = 1052,
CustomNamespaceInvalidCharacter = 1053,
ServiceTemporaryUnavailable = 1054,
DecodeUpdateError = 1055,
ApplyUpdateError = 1056,
}
impl ErrorCode {

View file

@ -15,11 +15,11 @@ reqwest = { version = "0.12", features = [
serde = { version = "1.0.199", features = ["derive"], optional = true }
serde_json = { version = "1.0", optional = true }
thiserror = "1.0.58"
anyhow = "1.0.81"
anyhow.workspace = true
tracing = { version = "0.1", optional = true }
serde_repr = { version = "0.1", optional = true }
futures = "0.3.30"
bytes = "1.6.0"
bytes.workspace = true
pin-project = "1.1.5"
[dev-dependencies]

View file

@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bytes = "1.5.0"
bytes.workspace = true
mime = "0.3.17"
serde_json = "1.0.111"
tokio = { workspace = true, features = ["sync"] }
@ -35,7 +35,7 @@ reqwest.workspace = true
gotrue.workspace = true
client-websocket.workspace = true
futures = "0.3.30"
anyhow = "1.0.80"
anyhow.workspace = true
serde = { version = "1.0.199", features = ["derive"] }
hex = "0.4.3"
async-trait.workspace = true

View file

@ -65,6 +65,17 @@ pub struct TestCollab {
pub origin: CollabOrigin,
pub collab: Arc<RwLock<dyn BorrowMut<Collab> + Send + Sync + 'static>>,
}
impl TestCollab {
pub async fn encode_collab(&self) -> EncodedCollab {
let lock = self.collab.read().await;
let collab = (*lock).borrow();
collab
.encode_collab_v1(|_| Ok::<(), anyhow::Error>(()))
.unwrap()
}
}
impl TestClient {
pub async fn new(registered_user: User, start_ws_conn: bool) -> Self {
load_env();
@ -1026,7 +1037,7 @@ pub async fn assert_server_collab(
};
if timeout(duration, operation).await.is_err() {
eprintln!("json : {}, expected: {}", final_json.lock().await, expected);
eprintln!("json:{}\nexpected:{}", final_json.lock().await, expected);
return Err(anyhow!("time out for the action"));
}
Ok(())

View file

@ -9,12 +9,12 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
reqwest = { workspace = true, features = ["multipart"] }
anyhow = "1.0.79"
anyhow.workspace = true
serde_repr = "0.1.18"
gotrue = { path = "../gotrue" }
tracing = { version = "0.1" }
thiserror = "1.0.56"
bytes = "1.5"
bytes = "1.9.0"
uuid.workspace = true
futures-util = "0.3.30"
futures-core = "0.3.30"
@ -29,6 +29,8 @@ tokio-stream = { version = "0.1.14" }
chrono = "0.4"
client-websocket = { workspace = true, features = ["native-tls"] }
semver = "1.0.22"
zstd = { version = "0.13.2" }
collab = { workspace = true, optional = true }
yrs = { workspace = true, optional = true }

View file

@ -244,7 +244,6 @@ where
}
}
}
// Check if all non-ping messages have been sent
let all_non_ping_messages_sent = !message_queue
.iter()

View file

@ -1,7 +1,6 @@
mod collab_sink;
mod collab_stream;
mod error;
mod period_state_check;
mod plugin;
mod sync_control;

View file

@ -1,86 +0,0 @@
use crate::collab_sync::{CollabSinkState, SinkQueue, SinkSignal};
use collab::core::origin::CollabOrigin;
use collab_rt_entity::{ClientCollabMessage, CollabStateCheck, SinkMessage};
use std::sync::atomic::Ordering;
use std::sync::{Arc, Weak};
use std::time::Duration;
use tokio::sync::watch;
use tokio::time::{sleep_until, Instant};
use tracing::warn;
#[allow(dead_code)]
pub struct CollabStateCheckRunner;
impl CollabStateCheckRunner {
#[allow(dead_code)]
pub(crate) fn run(
origin: CollabOrigin,
object_id: String,
message_queue: Weak<parking_lot::Mutex<SinkQueue<ClientCollabMessage>>>,
weak_notify: Weak<watch::Sender<SinkSignal>>,
state: Arc<CollabSinkState>,
) {
let duration = if cfg!(feature = "test_fast_sync") {
Duration::from_secs(10)
} else {
Duration::from_secs(20)
};
let mut next_tick = Instant::now() + duration;
tokio::spawn(async move {
loop {
sleep_until(next_tick).await;
// Set the next tick to the current time plus the duration.
// Otherwise, it might spike the CPU usage.
next_tick = Instant::now() + duration;
match message_queue.upgrade() {
None => {
if cfg!(feature = "sync_verbose_log") {
tracing::warn!("{} message queue dropped", object_id);
}
break;
},
Some(message_queue) => {
if state.pause_ping.load(Ordering::SeqCst) {
continue;
} else {
// Skip this iteration if a message was sent recently, within the specified duration.
if !state.latest_sync.is_time_for_next_sync(duration).await {
continue;
}
if let Some(mut queue) = message_queue.try_lock() {
let is_not_empty = queue.iter().any(|item| !item.message().is_ping_sync());
if is_not_empty {
if cfg!(feature = "sync_verbose_log") {
tracing::trace!("{} slow down check", object_id);
}
next_tick = Instant::now() + Duration::from_secs(30);
}
let msg_id = state.id_counter.next();
let check = CollabStateCheck {
origin: origin.clone(),
object_id: object_id.clone(),
msg_id,
};
queue.push_msg(msg_id, ClientCollabMessage::ClientCollabStateCheck(check));
// notify the sink to proceed next message
if let Some(notify) = weak_notify.upgrade() {
if let Err(err) = notify.send(SinkSignal::Proceed) {
warn!("{} fail to send notify signal: {}", object_id, err);
break;
}
}
}
}
},
}
}
});
}
}

View file

@ -1086,7 +1086,7 @@ impl Client {
let headers = [
("client-version", self.client_version.to_string()),
("client-timestamp", ts_now.to_string()),
("device_id", self.device_id.clone()),
("device-id", self.device_id.clone()),
("ai-model", self.ai_model.read().to_str().to_string()),
];
trace!(

View file

@ -1,5 +1,7 @@
use crate::entity::CollabType;
use crate::http::log_request_id;
use crate::{blocking_brotli_compress, brotli_compress, Client};
use anyhow::anyhow;
use app_error::AppError;
use bytes::Bytes;
use chrono::{DateTime, Utc};
@ -8,9 +10,10 @@ use client_api_entity::workspace_dto::{
DatabaseRowUpdatedItem, ListDatabaseRowDetailParam, ListDatabaseRowUpdatedParam,
};
use client_api_entity::{
BatchQueryCollabParams, BatchQueryCollabResult, CollabParams, CreateCollabParams,
AFCollabInfo, BatchQueryCollabParams, BatchQueryCollabResult, CollabParams, CreateCollabParams,
DeleteCollabParams, PublishCollabItem, QueryCollab, QueryCollabParams, UpdateCollabWebParams,
};
use collab_rt_entity::collab_proto::{CollabDocStateParams, PayloadCompressionType};
use collab_rt_entity::HttpRealtimeMessage;
use futures::Stream;
use futures_util::stream;
@ -21,6 +24,7 @@ use serde::Serialize;
use shared_entity::dto::workspace_dto::{CollabResponse, CollabTypeParam};
use shared_entity::response::{AppResponse, AppResponseError};
use std::future::Future;
use std::io::Cursor;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
@ -403,6 +407,77 @@ impl Client {
.await?;
AppResponse::<()>::from_response(resp).await?.into_error()
}
pub async fn get_collab_info(
&self,
workspace_id: &str,
object_id: &str,
collab_type: CollabType,
) -> Result<AFCollabInfo, AppResponseError> {
let url = format!(
"{}/api/workspace/{}/collab/{}/info",
self.base_url, workspace_id, object_id
);
let resp = self
.http_client_with_auth(Method::GET, &url)
.await?
.query(&CollabTypeParam { collab_type })
.send()
.await?;
log_request_id(&resp);
AppResponse::<AFCollabInfo>::from_response(resp)
.await?
.into_data()
}
pub async fn post_collab_doc_state(
&self,
workspace_id: &str,
object_id: &str,
collab_type: CollabType,
doc_state: Vec<u8>,
state_vector: Vec<u8>,
) -> Result<Vec<u8>, AppResponseError> {
let url = format!(
"{}/api/workspace/v1/{workspace_id}/collab/{object_id}/sync",
self.base_url
);
// 3 is default level
let doc_state = zstd::encode_all(Cursor::new(doc_state), 3)
.map_err(|err| AppError::InvalidRequest(format!("Failed to compress text: {}", err)))?;
let sv = zstd::encode_all(Cursor::new(state_vector), 3)
.map_err(|err| AppError::InvalidRequest(format!("Failed to compress text: {}", err)))?;
let params = CollabDocStateParams {
object_id: object_id.to_string(),
collab_type: collab_type.value(),
compression: PayloadCompressionType::Zstd as i32,
sv,
doc_state,
};
let mut encoded_payload = Vec::new();
params.encode(&mut encoded_payload).map_err(|err| {
AppError::Internal(anyhow!("Failed to encode CollabDocStateParams: {}", err))
})?;
let resp = self
.http_client_with_auth(Method::POST, &url)
.await?
.body(Bytes::from(encoded_payload))
.send()
.await?;
log_request_id(&resp);
if resp.status().is_success() {
let body = resp.bytes().await?;
let decompressed_body = zstd::decode_all(Cursor::new(body))?;
Ok(decompressed_body)
} else {
AppResponse::from_response(resp).await?.into_data()
}
}
}
struct RetryGetCollabCondition;

View file

@ -2,8 +2,8 @@ use crate::http::log_request_id;
use crate::Client;
use client_api_entity::{
AFCollabMember, AFCollabMembers, AFWorkspaceInvitation, AFWorkspaceInvitationStatus,
AFWorkspaceMember, CollabMemberIdentify, InsertCollabMemberParams, QueryCollabMembers,
QueryWorkspaceMember, UpdateCollabMemberParams,
AFWorkspaceMember, InsertCollabMemberParams, QueryCollabMembers, QueryWorkspaceMember,
UpdateCollabMemberParams, WorkspaceCollabIdentify,
};
use reqwest::Method;
use shared_entity::dto::workspace_dto::{
@ -205,7 +205,7 @@ impl Client {
#[instrument(level = "info", skip_all, err)]
pub async fn get_collab_member(
&self,
params: CollabMemberIdentify,
params: WorkspaceCollabIdentify,
) -> Result<AFCollabMember, AppResponseError> {
let url = format!(
"{}/api/workspace/{}/collab/{}/member",
@ -245,7 +245,7 @@ impl Client {
#[instrument(level = "info", skip_all, err)]
pub async fn remove_collab_member(
&self,
params: CollabMemberIdentify,
params: WorkspaceCollabIdentify,
) -> Result<(), AppResponseError> {
let url = format!(
"{}/api/workspace/{}/collab/{}/member",

View file

@ -302,7 +302,7 @@ impl WSClient {
RealtimeMessage::ServerCollabV1(collab_messages) => {
handle_collab_message(&weak_collab_channels, collab_messages);
},
RealtimeMessage::ClientCollabV1(_) | RealtimeMessage::ClientCollabV2(_) => {
RealtimeMessage::ClientCollabV2(_) | RealtimeMessage::ClientCollabV1(_) => {
// The message from server should not be collab message.
error!(
"received unexpected collab message from websocket: {:?}",

View file

@ -8,8 +8,8 @@ use tokio::time::{sleep_until, Instant};
use tracing::{error, trace};
use client_websocket::Message;
use collab_rt_entity::RealtimeMessage;
use collab_rt_entity::{ClientCollabMessage, MsgId};
use collab_rt_entity::{MessageByObjectId, RealtimeMessage};
pub type AggregateMessagesSender = mpsc::Sender<Message>;
pub type AggregateMessagesReceiver = mpsc::Receiver<Message>;
@ -115,7 +115,7 @@ async fn send_batch_message(
sender: &AggregateMessagesSender,
messages_map: HashMap<String, Vec<ClientCollabMessage>>,
) {
match RealtimeMessage::ClientCollabV2(messages_map).encode() {
match RealtimeMessage::ClientCollabV2(MessageByObjectId(messages_map)).encode() {
Ok(data) => {
if let Err(e) = sender.send(Message::Binary(data)).await {
trace!("websocket channel close:{}, stop sending messages", e);

View file

@ -13,7 +13,7 @@ collab-entity = { workspace = true }
serde.workspace = true
serde_json.workspace = true
bytes = { version = "1.5", features = ["serde"] }
anyhow = "1.0.79"
anyhow.workspace = true
actix = { version = "0.13", optional = true }
bincode.workspace = true
tokio-tungstenite = { version = "0.20.1", optional = true }

View file

@ -12,7 +12,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
prost_build::Config::new()
.out_dir("src/")
.compile_protos(&["proto/realtime.proto"], &["proto/"])?;
.compile_protos(&["proto/realtime.proto", "proto/collab.proto"], &["proto/"])?;
// Run rustfmt on the generated files.
let files = std::fs::read_dir("src/")?

View file

@ -0,0 +1,16 @@
syntax = "proto3";
package collab_proto;
enum PayloadCompressionType {
NONE = 0;
ZSTD = 1;
}
message CollabDocStateParams {
string object_id = 1;
int32 collab_type = 2;
PayloadCompressionType compression = 3;
bytes sv = 4;
bytes doc_state = 5;
}

View file

@ -1,6 +1,6 @@
use crate::message::RealtimeMessage;
use crate::server_message::ServerInit;
use crate::{CollabMessage, MsgId};
use crate::{CollabMessage, MessageByObjectId, MsgId};
use anyhow::{anyhow, Error};
use bytes::Bytes;
use collab::core::origin::CollabOrigin;
@ -44,9 +44,6 @@ impl ClientCollabMessage {
Self::ServerInitSync(data)
}
pub fn new_awareness_sync(data: UpdateSync) -> Self {
Self::ClientAwarenessSync(data)
}
pub fn size(&self) -> usize {
match self {
ClientCollabMessage::ClientInitSync { data, .. } => data.payload.len(),
@ -136,7 +133,7 @@ impl TryFrom<CollabMessage> for ClientCollabMessage {
},
CollabMessage::ServerInitSync(msg) => Ok(ClientCollabMessage::ServerInitSync(msg)),
_ => Err(anyhow!(
"Can't convert to ClientCollabMessage for given collab message:{}",
"Can't convert to ClientCollabMessage for value:{}",
value
)),
}
@ -146,7 +143,8 @@ impl TryFrom<CollabMessage> for ClientCollabMessage {
impl From<ClientCollabMessage> for RealtimeMessage {
fn from(msg: ClientCollabMessage) -> Self {
let object_id = msg.object_id().to_string();
Self::ClientCollabV2([(object_id, vec![msg])].into())
let message = MessageByObjectId::new_with_message(object_id, vec![msg]);
Self::ClientCollabV2(message)
}
}
@ -186,6 +184,7 @@ impl SinkMessage for ClientCollabMessage {
fn is_update_sync(&self) -> bool {
matches!(self, ClientCollabMessage::ClientUpdateSync { .. })
}
fn is_ping_sync(&self) -> bool {
matches!(self, ClientCollabMessage::ClientCollabStateCheck { .. })
}

View file

@ -0,0 +1,41 @@
// This file is @generated by prost-build.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CollabDocStateParams {
#[prost(string, tag = "1")]
pub object_id: ::prost::alloc::string::String,
#[prost(int32, tag = "2")]
pub collab_type: i32,
#[prost(enumeration = "PayloadCompressionType", tag = "3")]
pub compression: i32,
#[prost(bytes = "vec", tag = "4")]
pub sv: ::prost::alloc::vec::Vec<u8>,
#[prost(bytes = "vec", tag = "5")]
pub doc_state: ::prost::alloc::vec::Vec<u8>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum PayloadCompressionType {
None = 0,
Zstd = 1,
}
impl PayloadCompressionType {
/// String value of the enum field names used in the ProtoBuf definition.
///
/// The values are not transformed in any way and thus are considered stable
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
pub fn as_str_name(&self) -> &'static str {
match self {
PayloadCompressionType::None => "NONE",
PayloadCompressionType::Zstd => "ZSTD",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
match value {
"NONE" => Some(Self::None),
"ZSTD" => Some(Self::Zstd),
_ => None,
}
}
}

View file

@ -8,6 +8,7 @@ mod client_message;
// cargo clean
// cargo build
// ```
pub mod collab_proto;
pub mod realtime_proto;
mod server_message;

View file

@ -14,6 +14,7 @@ use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};
#[cfg(feature = "rt_compress")]
use std::io::Read;
use std::ops::{Deref, DerefMut};
/// Maximum allowable size for a realtime message.
///
@ -27,7 +28,32 @@ pub const MAXIMUM_REALTIME_MESSAGE_SIZE: u64 = 10 * 1024 * 1024; // 10 MB
#[cfg(feature = "rt_compress")]
const COMPRESSED_PREFIX: &[u8] = b"COMPRESSED:1";
pub type MessageByObjectId = HashMap<String, Vec<ClientCollabMessage>>;
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct MessageByObjectId(pub HashMap<String, Vec<ClientCollabMessage>>);
impl MessageByObjectId {
pub fn new_with_message(object_id: String, messages: Vec<ClientCollabMessage>) -> Self {
let mut map = HashMap::with_capacity(1);
map.insert(object_id, messages);
Self(map)
}
pub fn into_inner(self) -> HashMap<String, Vec<ClientCollabMessage>> {
self.0
}
}
impl Deref for MessageByObjectId {
type Target = HashMap<String, Vec<ClientCollabMessage>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for MessageByObjectId {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
@ -36,47 +62,30 @@ pub type MessageByObjectId = HashMap<String, Vec<ClientCollabMessage>>;
rtype(result = "()")
)]
pub enum RealtimeMessage {
Collab(CollabMessage),
Collab(CollabMessage), // Deprecated
User(UserMessage),
System(SystemMessage),
ClientCollabV1(Vec<ClientCollabMessage>),
ClientCollabV1(Vec<ClientCollabMessage>), // Deprecated
ClientCollabV2(MessageByObjectId),
ServerCollabV1(Vec<ServerCollabMessage>),
}
impl RealtimeMessage {
pub fn size(&self) -> usize {
match self {
RealtimeMessage::Collab(msg) => msg.len(),
RealtimeMessage::User(_) => 1,
RealtimeMessage::System(_) => 1,
RealtimeMessage::ClientCollabV1(msgs) => msgs.iter().map(|msg| msg.size()).sum(),
RealtimeMessage::ClientCollabV2(msgs) => msgs
.iter()
.map(|(_, value)| value.iter().map(|v| v.size()).sum::<usize>())
.sum(),
RealtimeMessage::ServerCollabV1(msgs) => msgs.iter().map(|msg| msg.size()).sum(),
}
}
/// Convert RealtimeMessage to ClientCollabMessage
/// If the message is not a collab message, it will return an empty vec
/// If the message is a collab message, it will return a vec with one element
/// If the message is a ClientCollabV1, it will return list of collab messages
pub fn transform(self) -> Result<MessageByObjectId, Error> {
pub fn split_messages_by_object_id(self) -> Result<MessageByObjectId, Error> {
match self {
RealtimeMessage::Collab(collab_message) => {
let object_id = collab_message.object_id().to_string();
let collab_message = ClientCollabMessage::try_from(collab_message)?;
Ok([(object_id, vec![collab_message])].into())
},
RealtimeMessage::ClientCollabV1(collab_messages) => {
let message_map: MessageByObjectId = collab_messages
.into_iter()
.map(|message| (message.object_id().to_string(), vec![message]))
.collect();
Ok(message_map)
let message = MessageByObjectId::new_with_message(
object_id,
vec![ClientCollabMessage::try_from(collab_message)?],
);
Ok(message)
},
RealtimeMessage::ClientCollabV1(_) => Err(anyhow!("ClientCollabV1 is not supported")),
RealtimeMessage::ClientCollabV2(collab_messages) => Ok(collab_messages),
_ => Err(anyhow!(
"Failed to convert RealtimeMessage:{} to ClientCollabMessage",

View file

@ -2,7 +2,7 @@ use bytes::Bytes;
use collab::core::origin::CollabOrigin;
use collab_entity::CollabType;
use collab_rt_entity::user::UserMessage;
use collab_rt_entity::{ClientCollabMessage, CollabMessage, InitSync, MsgId};
use collab_rt_entity::{CollabMessage, InitSync, MsgId};
use collab_rt_entity::{RealtimeMessage, SystemMessage};
use serde::{Deserialize, Serialize};
use std::fs::File;
@ -50,24 +50,6 @@ fn decode_0149_realtime_message_test() {
} else {
panic!("Failed to decode RealtimeMessage from file");
}
let client_collab_v1 = read_message_from_file("migration/0149/client_collab_v1").unwrap();
assert!(matches!(
client_collab_v1,
RealtimeMessage::ClientCollabV1(_)
));
if let RealtimeMessage::ClientCollabV1(messages) = client_collab_v1 {
assert_eq!(messages.len(), 1);
if let ClientCollabMessage::ClientUpdateSync { data } = &messages[0] {
assert_eq!(data.object_id, "object id 1");
assert_eq!(data.msg_id, 10);
assert_eq!(data.payload, Bytes::from(vec![5, 6, 7, 8]));
} else {
panic!("Failed to decode RealtimeMessage from file");
}
} else {
panic!("Failed to decode RealtimeMessage from file");
}
}
#[test]

View file

@ -10,7 +10,7 @@ redis = { workspace = true, features = ["aio", "tokio-comp", "connection-manager
tokio = { version = "1.26", features = ["rt-multi-thread", "macros"] }
tokio-stream = { version = "0.1.14" }
thiserror = "1.0.58"
anyhow = "1.0.81"
anyhow.workspace = true
futures = "0.3.30"
tracing = "0.1"
serde = { version = "1", features = ["derive"] }

View file

@ -15,7 +15,7 @@ validator = { workspace = true, features = ["validator_derive", "derive"] }
chrono = { version = "0.4", features = ["serde"] }
uuid = { workspace = true, features = ["serde", "v4"] }
thiserror = "1.0.56"
anyhow = "1.0.79"
anyhow.workspace = true
tracing = "0.1"
serde_repr = "0.1.18"
app-error = { workspace = true }

View file

@ -379,7 +379,7 @@ pub struct InsertCollabMemberParams {
pub type UpdateCollabMemberParams = InsertCollabMemberParams;
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
pub struct CollabMemberIdentify {
pub struct WorkspaceCollabIdentify {
pub uid: i64,
#[validate(custom(function = "validate_not_empty_str"))]
pub workspace_id: String,
@ -427,6 +427,13 @@ pub struct AFCollabMember {
pub permission: AFPermission,
}
#[derive(Serialize, Deserialize)]
pub struct AFCollabInfo {
pub object_id: String,
/// The timestamp when the object's embeddings were last updated
pub embedding_index_at: DateTime<Utc>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PublishInfo {
pub namespace: String,

View file

@ -16,7 +16,7 @@ app-error = { workspace = true, features = ["sqlx_error", "validation_error"] }
tokio = { workspace = true, features = ["sync"] }
async-trait.workspace = true
anyhow = "1.0.79"
anyhow.workspace = true
serde.workspace = true
serde_json.workspace = true
tonic-proto.workspace = true
@ -40,7 +40,7 @@ redis = { workspace = true, features = [
"connection-manager",
] }
futures-util = "0.3.30"
bytes = "1.5"
bytes.workspace = true
aws-sdk-s3 = { version = "1.36.0", features = [
"behavior-version-latest",
"rt-tokio",

View file

@ -1,8 +1,8 @@
use anyhow::{anyhow, Context};
use collab_entity::CollabType;
use database_entity::dto::{
AFAccessLevel, AFCollabMember, AFPermission, AFSnapshotMeta, AFSnapshotMetas, CollabParams,
QueryCollab, QueryCollabResult, RawData,
AFAccessLevel, AFCollabInfo, AFCollabMember, AFPermission, AFSnapshotMeta, AFSnapshotMetas,
CollabParams, QueryCollab, QueryCollabResult, RawData,
};
use shared_entity::dto::workspace_dto::DatabaseRowUpdatedItem;
@ -749,3 +749,33 @@ pub async fn select_last_updated_database_row_ids(
.await?;
Ok(updated_row_items)
}
pub async fn get_collab_info<'a, E>(
tx: E,
object_id: &str,
collab_type: CollabType,
) -> Result<Option<AFCollabInfo>, sqlx::Error>
where
E: Executor<'a, Database = Postgres>,
{
let partition_key = crate::collab::partition_key_from_collab_type(&collab_type);
let result = sqlx::query!(
r#"
SELECT
oid AS object_id,
indexed_at
FROM af_collab_embeddings
WHERE oid = $1 AND partition_key = $2
"#,
object_id,
partition_key
)
.fetch_optional(tx)
.await?
.map(|row| AFCollabInfo {
object_id: row.object_id,
embedding_index_at: DateTime::<Utc>::from_naive_utc_and_offset(row.indexed_at, Utc),
});
Ok(result)
}

View file

@ -7,7 +7,6 @@ use database_entity::dto::{
};
use collab::entity::EncodedCollab;
use collab_rt_entity::ClientCollabMessage;
use serde::{Deserialize, Serialize};
use sqlx::Transaction;
use std::collections::HashMap;
@ -64,7 +63,7 @@ pub trait CollabStorage: Send + Sync + 'static {
/// * `workspace_id` - The ID of the workspace.
/// * `uid` - The ID of the user.
/// * `params` - The parameters containing the data of the collaboration.
/// * `write_immediately` - A boolean value that indicates whether the data should be written immediately.
/// * `flush_to_disk` - A boolean value that indicates whether the data should be written immediately.
if flush_to_disk is true, the data will be written to disk immediately. Otherwise, the data will
/// be scheduled to be written to disk later.
///
@ -73,7 +72,7 @@ pub trait CollabStorage: Send + Sync + 'static {
workspace_id: &str,
uid: &i64,
params: CollabParams,
write_immediately: bool,
flush_to_disk: bool,
) -> AppResult<()>;
async fn batch_insert_new_collab(
@ -117,16 +116,6 @@ pub trait CollabStorage: Send + Sync + 'static {
from_editing_collab: bool,
) -> AppResult<EncodedCollab>;
/// Sends a collab message to all connected clients.
/// # Arguments
/// * `object_id` - The ID of the collaboration object.
/// * `collab_messages` - The list of collab messages to broadcast.
async fn broadcast_encode_collab(
&self,
object_id: String,
collab_messages: Vec<ClientCollabMessage>,
) -> Result<(), AppError>;
async fn batch_get_collab(
&self,
uid: &i64,

View file

@ -1,9 +1,8 @@
use std::ops::DerefMut;
use collab_entity::CollabType;
use pgvector::Vector;
use sqlx::postgres::{PgHasArrayType, PgTypeInfo};
use sqlx::{Error, Executor, Postgres, Transaction};
use std::ops::DerefMut;
use uuid::Uuid;
use database_entity::dto::{

View file

@ -6,10 +6,10 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
x25519-dalek = { version = "2.0.0" , features = ["getrandom"] }
x25519-dalek = { version = "2.0.0", features = ["getrandom"] }
rand = "0.8.5"
hex = "0.4.3"
anyhow = "1.0.79"
anyhow.workspace = true
aes-gcm = { version = "0.10.3" }
base64 = "0.21.7"
hkdf = { version = "0.12.4" }

View file

@ -8,7 +8,7 @@ edition = "2021"
[dependencies]
serde.workspace = true
serde_json.workspace = true
anyhow = "1.0.79"
anyhow.workspace = true
lazy_static = "1.4.0"
jsonwebtoken = "8.3.0"
app-error = { workspace = true, features = ["gotrue_error"] }

View file

@ -11,7 +11,7 @@ crate-type = ["cdylib", "rlib"]
serde.workspace = true
serde_json.workspace = true
futures-util = "0.3.30"
anyhow = "1.0.79"
anyhow.workspace = true
reqwest = { workspace = true, features = ["json", "rustls-tls", "cookies"] }
tokio = { workspace = true, features = ["sync", "macros"] }
infra = { path = "../infra", features = ["request_util"] }

View file

@ -7,7 +7,7 @@ edition = "2021"
[dependencies]
reqwest = { workspace = true, optional = true }
anyhow = "1.0.79"
anyhow.workspace = true
serde.workspace = true
serde_json.workspace = true
tracing.workspace = true

View file

@ -8,7 +8,7 @@ edition = "2021"
crate-type = ["cdylib", "rlib"]
[dependencies]
anyhow = "1.0.79"
anyhow.workspace = true
serde = "1.0.195"
serde_json.workspace = true
serde_repr = "0.1.18"
@ -31,7 +31,7 @@ actix-web = { version = "4.4.1", default-features = false, features = [
], optional = true }
validator = { workspace = true, features = ["validator_derive", "derive"] }
futures = "0.3.30"
bytes = "1.6.0"
bytes.workspace = true
log = "0.4.21"
tracing = { workspace = true }

View file

@ -120,7 +120,7 @@ pub struct PatchWorkspaceParam {
pub workspace_icon: Option<String>,
}
#[derive(Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct CollabTypeParam {
pub collab_type: CollabType,
}

View file

@ -1,2 +1,2 @@
[toolchain]
channel = "1.83.0"
channel = "1.80.0"

View file

@ -3,7 +3,7 @@
# Generate the current dependency list
cargo tree > current_deps.txt
BASELINE_COUNT=720
BASELINE_COUNT=722
CURRENT_COUNT=$(cat current_deps.txt | wc -l)
echo "Expected dependency count (baseline): $BASELINE_COUNT"

View file

@ -59,7 +59,7 @@ sqlx = { workspace = true, default-features = false, features = [
] }
thiserror = "1.0.56"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
anyhow = "1"
anyhow.workspace = true
bytes.workspace = true
collab = { workspace = true }

View file

@ -1,4 +1,4 @@
use crate::actix_ws::entities::{ClientMessage, Connect, Disconnect, RealtimeMessage};
use crate::actix_ws::entities::{ClientWebSocketMessage, Connect, Disconnect, RealtimeMessage};
use crate::error::RealtimeError;
use crate::RealtimeClientWebsocketSink;
use actix::{
@ -27,7 +27,7 @@ use tracing::{debug, error, trace, warn};
pub type HandlerResult = anyhow::Result<(), RealtimeError>;
pub trait RealtimeServer:
Actor<Context = Context<Self>>
+ Handler<ClientMessage, Result = HandlerResult>
+ Handler<ClientWebSocketMessage, Result = HandlerResult>
+ Handler<Connect, Result = HandlerResult>
+ Handler<Disconnect, Result = HandlerResult>
{
@ -103,7 +103,7 @@ where
self
.server
.try_send(ClientMessage {
.try_send(ClientWebSocketMessage {
user: self.user.clone(),
message,
})
@ -127,7 +127,7 @@ where
let fut = async move {
match tokio::task::spawn_blocking(move || RealtimeMessage::decode(&bytes)).await {
Ok(Ok(decoded_message)) => {
let mut client_message = Some(ClientMessage {
let mut client_message = Some(ClientWebSocketMessage {
user,
message: decoded_message,
});

View file

@ -1,12 +1,13 @@
use crate::error::RealtimeError;
use actix::{Message, Recipient};
use app_error::AppError;
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::fmt::Debug;
use bytes::Bytes;
use collab_entity::CollabType;
use collab_rt_entity::user::RealtimeUser;
pub use collab_rt_entity::RealtimeMessage;
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::fmt::Debug;
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), RealtimeError>")]
pub struct Connect {
@ -31,15 +32,40 @@ pub enum BusinessID {
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), RealtimeError>")]
pub struct ClientMessage {
pub struct ClientWebSocketMessage {
pub user: RealtimeUser,
pub message: RealtimeMessage,
}
#[derive(Message)]
#[rtype(result = "Result<(), RealtimeError>")]
pub struct ClientStreamMessage {
pub struct ClientHttpStreamMessage {
pub uid: i64,
pub device_id: String,
pub message: RealtimeMessage,
}
#[derive(Message)]
#[rtype(result = "Result<(), AppError>")]
pub struct ClientHttpUpdateMessage {
pub user: RealtimeUser,
pub workspace_id: String,
pub object_id: String,
/// Encoded yrs::Update or doc state
pub update: Bytes,
/// If the state_vector is not None, it will calculate missing updates base on
/// given state_vector after apply the update
pub state_vector: Option<Bytes>,
pub collab_type: CollabType,
/// If return_tx is Some, calling await on its receiver will wait until the update was applied
/// to the collab. The return value will be None if the input state_vector is None.
pub return_tx: Option<tokio::sync::oneshot::Sender<Result<Option<Vec<u8>>, AppError>>>,
}
#[derive(Message)]
#[rtype(result = "Result<(), AppError>")]
pub struct ClientGenerateEmbeddingMessage {
pub workspace_id: String,
pub object_id: String,
pub return_tx: Option<tokio::sync::oneshot::Sender<Result<(), AppError>>>,
}

View file

@ -1,15 +1,19 @@
use std::ops::Deref;
use actix::{Actor, Context, Handler};
use tracing::{error, info, warn};
use crate::error::RealtimeError;
use crate::CollaborationServer;
use actix::{Actor, Context, Handler};
use anyhow::anyhow;
use app_error::AppError;
use collab_rt_entity::user::UserDevice;
use database::collab::CollabStorage;
use tracing::{error, info, trace, warn};
use crate::actix_ws::client::rt_client::{RealtimeClientWebsocketSinkImpl, RealtimeServer};
use crate::actix_ws::entities::{ClientMessage, ClientStreamMessage, Connect, Disconnect};
use crate::actix_ws::entities::{
ClientGenerateEmbeddingMessage, ClientHttpStreamMessage, ClientHttpUpdateMessage,
ClientWebSocketMessage, Connect, Disconnect,
};
#[derive(Clone)]
pub struct RealtimeServerActor<S>(pub CollaborationServer<S>);
@ -81,15 +85,19 @@ where
}
}
impl<S> Handler<ClientMessage> for RealtimeServerActor<S>
impl<S> Handler<ClientWebSocketMessage> for RealtimeServerActor<S>
where
S: CollabStorage + Unpin,
{
type Result = anyhow::Result<(), RealtimeError>;
fn handle(&mut self, client_msg: ClientMessage, _ctx: &mut Context<Self>) -> Self::Result {
let ClientMessage { user, message } = client_msg;
match message.transform() {
fn handle(
&mut self,
client_msg: ClientWebSocketMessage,
_ctx: &mut Context<Self>,
) -> Self::Result {
let ClientWebSocketMessage { user, message } = client_msg;
match message.split_messages_by_object_id() {
Ok(message_by_object_id) => self.handle_client_message(user, message_by_object_id),
Err(err) => {
if cfg!(debug_assertions) {
@ -101,14 +109,18 @@ where
}
}
impl<S> Handler<ClientStreamMessage> for RealtimeServerActor<S>
impl<S> Handler<ClientHttpStreamMessage> for RealtimeServerActor<S>
where
S: CollabStorage + Unpin,
{
type Result = anyhow::Result<(), RealtimeError>;
fn handle(&mut self, client_msg: ClientStreamMessage, _ctx: &mut Context<Self>) -> Self::Result {
let ClientStreamMessage {
fn handle(
&mut self,
client_msg: ClientHttpStreamMessage,
_ctx: &mut Context<Self>,
) -> Self::Result {
let ClientHttpStreamMessage {
uid,
device_id,
message,
@ -117,7 +129,7 @@ where
// Get the real-time user by the device ID and user ID. If the user is not found, which means
// the user is not connected to the real-time server via websocket.
let user = self.get_user_by_device(&UserDevice::new(&device_id, uid));
match (user, message.transform()) {
match (user, message.split_messages_by_object_id()) {
(Some(user), Ok(messages)) => self.handle_client_message(user, messages),
(None, _) => {
warn!("Can't find the realtime user uid:{}, device:{}. User should connect via websocket before", uid,device_id);
@ -132,3 +144,41 @@ where
}
}
}
impl<S> Handler<ClientHttpUpdateMessage> for RealtimeServerActor<S>
where
S: CollabStorage + Unpin,
{
type Result = Result<(), AppError>;
fn handle(&mut self, msg: ClientHttpUpdateMessage, _ctx: &mut Self::Context) -> Self::Result {
trace!("Receive client http update message");
self
.handle_client_http_update(msg)
.map_err(|err| AppError::Internal(anyhow!("handle client http message error: {}", err)))?;
Ok(())
}
}
impl<S> Handler<ClientGenerateEmbeddingMessage> for RealtimeServerActor<S>
where
S: CollabStorage + Unpin,
{
type Result = Result<(), AppError>;
fn handle(
&mut self,
msg: ClientGenerateEmbeddingMessage,
_ctx: &mut Self::Context,
) -> Self::Result {
self
.handle_client_generate_embedding_request(msg)
.map_err(|err| {
AppError::Internal(anyhow!(
"handle client generate embedding request error: {}",
err
))
})?;
Ok(())
}
}

View file

@ -24,7 +24,7 @@ use collab_rt_entity::{HttpRealtimeMessage, RealtimeMessage};
use shared_entity::response::{AppResponse, AppResponseError};
use crate::actix_ws::client::RealtimeClient;
use crate::actix_ws::entities::ClientStreamMessage;
use crate::actix_ws::entities::ClientHttpStreamMessage;
use crate::actix_ws::server::RealtimeServerActor;
use crate::collab::storage::CollabAccessControlStorage;
use crate::compression::{
@ -101,7 +101,7 @@ async fn post_realtime_message_stream_handler(
req: HttpRequest,
) -> Result<Json<AppResponse<()>>> {
// TODO(nathan): after upgrade the client application, then the device_id should not be empty
let device_id = device_id_from_headers(req.headers()).unwrap_or_else(|_| "".to_string());
let device_id = device_id_from_headers(req.headers()).unwrap_or("");
let uid = state
.user_cache
.get_user_uid(&user_uuid)
@ -117,7 +117,7 @@ async fn post_realtime_message_stream_handler(
let device_id = device_id.to_string();
let message = parser_realtime_msg(bytes.freeze(), req.clone()).await?;
let stream_message = ClientStreamMessage {
let stream_message = ClientHttpStreamMessage {
uid,
device_id,
message,
@ -137,18 +137,29 @@ async fn post_realtime_message_stream_handler(
}
}
fn device_id_from_headers(headers: &HeaderMap) -> std::result::Result<String, AppError> {
headers
.get("device_id")
.ok_or(AppError::InvalidRequest(
"Missing device_id header".to_string(),
))
fn value_from_headers<'a>(
headers: &'a HeaderMap,
keys: &[&str],
missing_msg: &str,
) -> std::result::Result<&'a str, AppError> {
keys
.iter()
.find_map(|key| headers.get(*key))
.ok_or_else(|| AppError::InvalidRequest(missing_msg.to_string()))
.and_then(|header| {
header
.to_str()
.map_err(|err| AppError::InvalidRequest(format!("Failed to parse device_id: {}", err)))
.map_err(|err| AppError::InvalidRequest(format!("Failed to parse header: {}", err)))
})
.map(|s| s.to_string())
}
/// Retrieve device ID from headers
pub fn device_id_from_headers(headers: &HeaderMap) -> Result<&str, AppError> {
value_from_headers(
headers,
&["Device-Id", "device-id", "device_id", "Device-ID"],
"Missing Device-Id or device_id header",
)
}
fn compress_type_from_header_value(

View file

@ -31,7 +31,7 @@ pub struct ClientMessageRouter {
///
/// The message flow:
/// ClientSession(websocket) -> [CollabRealtimeServer] -> [ClientMessageRouter] -> [CollabBroadcast] 1->* websocket(client)
pub(crate) stream_tx: tokio::sync::broadcast::Sender<RealtimeMessage>,
pub(crate) stream_tx: tokio::sync::broadcast::Sender<MessageByObjectId>,
}
impl ClientMessageRouter {
@ -104,60 +104,54 @@ impl ClientMessageRouter {
let (client_msg_rx, rx) = tokio::sync::mpsc::channel(100);
let client_stream = ReceiverStream::new(rx);
tokio::spawn(async move {
while let Some(Ok(realtime_msg)) = stream_rx.next().await {
match realtime_msg.transform() {
Ok(messages_by_oid) => {
for (message_object_id, original_messages) in messages_by_oid {
// if the message is not for the target object, skip it. The stream_rx receives different
// objects' messages, so we need to filter out the messages that are not for the target object.
if target_object_id != message_object_id {
continue;
}
while let Some(Ok(messages_by_oid)) = stream_rx.next().await {
for (message_object_id, original_messages) in messages_by_oid.into_inner() {
// if the message is not for the target object, skip it. The stream_rx receives different
// objects' messages, so we need to filter out the messages that are not for the target object.
if target_object_id != message_object_id {
continue;
}
// before applying user messages, we need to check if the user has the permission
// valid_messages contains the messages that the user is allowed to apply
// invalid_message contains the messages that the user is not allowed to apply
let (valid_messages, invalid_message) = Self::access_control(
&stream_workspace_id,
&user.uid,
&message_object_id,
access_control.clone(),
original_messages,
)
.await;
trace!(
"{} receive client:{}, device:{}, message: valid:{} invalid:{}",
message_object_id,
user.uid,
user.device_id,
valid_messages.len(),
invalid_message.len()
);
// before applying user messages, we need to check if the user has the permission
// valid_messages contains the messages that the user is allowed to apply
// invalid_message contains the messages that the user is not allowed to apply
let (valid_messages, invalid_message) = Self::access_control(
&stream_workspace_id,
&user.uid,
&message_object_id,
access_control.clone(),
original_messages,
)
.await;
trace!(
"{} receive client:{}, device:{}, message: valid:{} invalid:{}",
message_object_id,
user.uid,
user.device_id,
valid_messages.len(),
invalid_message.len()
);
if valid_messages.is_empty() {
continue;
}
if valid_messages.is_empty() {
continue;
}
// if tx.send return error, it means the client is disconnected from the group
if let Err(err) = client_msg_rx
.send([(message_object_id, valid_messages)].into())
.await
{
trace!(
"{} send message to user:{} stream fail with error: {}, break the loop",
target_object_id,
user.user_device(),
err,
);
return;
}
}
},
Err(err) => {
if cfg!(debug_assertions) {
error!("parse client message error: {}", err);
}
},
// if tx.send return error, it means the client is disconnected from the group
if let Err(err) = client_msg_rx
.send(MessageByObjectId::new_with_message(
message_object_id,
valid_messages,
))
.await
{
trace!(
"{} send message to user:{} stream fail with error: {}, break the loop",
target_object_id,
user.user_device(),
err,
);
return;
}
}
}
});

View file

@ -1,10 +1,20 @@
#![allow(unused_imports)]
use crate::command::{CLCommandSender, CollaborationCommand};
use anyhow::{anyhow, Context};
use app_error::AppError;
use async_trait::async_trait;
use collab::entity::EncodedCollab;
use collab_entity::CollabType;
use collab_rt_entity::ClientCollabMessage;
use database::collab::{
insert_into_af_collab_bulk_for_user, AppResult, CollabMetadata, CollabStorage,
CollabStorageAccessControl, GetCollabOrigin,
};
use database_entity::dto::{
AFAccessLevel, AFSnapshotMeta, AFSnapshotMetas, CollabParams, InsertSnapshotParams,
PendingCollabWrite, QueryCollab, QueryCollabParams, QueryCollabResult, SnapshotData,
};
use itertools::{Either, Itertools};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use sqlx::Transaction;
@ -18,17 +28,7 @@ use tracing::warn;
use tracing::{error, instrument, trace};
use uuid::Uuid;
use validator::Validate;
use crate::command::{CLCommandSender, CollaborationCommand};
use app_error::AppError;
use database::collab::{
insert_into_af_collab_bulk_for_user, AppResult, CollabMetadata, CollabStorage,
CollabStorageAccessControl, GetCollabOrigin,
};
use database_entity::dto::{
AFAccessLevel, AFSnapshotMeta, AFSnapshotMetas, CollabParams, InsertSnapshotParams,
PendingCollabWrite, QueryCollab, QueryCollabParams, QueryCollabResult, SnapshotData,
};
use yrs::Update;
use crate::collab::access_control::CollabStorageAccessControlImpl;
use crate::collab::cache::CollabCache;
@ -131,7 +131,6 @@ where
.await?;
Ok(())
}
async fn get_encode_collab_from_editing(&self, oid: &str) -> Option<EncodedCollab> {
let object_id = oid.to_string();
let (ret, rx) = tokio::sync::oneshot::channel();
@ -244,6 +243,45 @@ where
.bulk_insert_collab(workspace_id, uid, params_list)
.await
}
/// Sends a collab message to all connected clients.
/// # Arguments
/// * `object_id` - The ID of the collaboration object.
/// * `collab_messages` - The list of collab messages to broadcast.
pub async fn broadcast_encode_collab(
&self,
object_id: String,
collab_messages: Vec<ClientCollabMessage>,
) -> Result<(), AppError> {
let (sender, recv) = tokio::sync::oneshot::channel();
self
.rt_cmd_sender
.send(CollaborationCommand::ServerSendCollabMessage {
object_id,
collab_messages,
ret: sender,
})
.await
.map_err(|err| {
AppError::Unhandled(format!(
"Failed to send encode collab command to realtime server: {}",
err
))
})?;
match recv.await {
Ok(res) =>
if let Err(err) = res {
error!("Failed to broadcast encode collab: {}", err);
}
,
// caller may have dropped the receiver
Err(err) => warn!("Failed to receive response from realtime server: {}", err),
}
Ok(())
}
}
#[async_trait]
@ -256,7 +294,7 @@ where
workspace_id: &str,
uid: &i64,
params: CollabParams,
write_immediately: bool,
flush_to_disk: bool,
) -> AppResult<()> {
params.validate()?;
let is_exist = self.cache.is_exist(workspace_id, &params.object_id).await?;
@ -280,7 +318,7 @@ where
.update_policy(uid, &params.object_id, AFAccessLevel::FullAccess)
.await?;
}
if write_immediately {
if flush_to_disk {
self.insert_collab(workspace_id, uid, params).await?;
} else {
self.queue_insert_collab(workspace_id, uid, params).await?;
@ -397,41 +435,6 @@ where
Ok(encode_collab)
}
async fn broadcast_encode_collab(
&self,
object_id: String,
collab_messages: Vec<ClientCollabMessage>,
) -> Result<(), AppError> {
let (sender, recv) = tokio::sync::oneshot::channel();
self
.rt_cmd_sender
.send(CollaborationCommand::ServerSendCollabMessage {
object_id,
collab_messages,
ret: sender,
})
.await
.map_err(|err| {
AppError::Unhandled(format!(
"Failed to send encode collab command to realtime server: {}",
err
))
})?;
match recv.await {
Ok(res) =>
if let Err(err) = res {
error!("Failed to broadcast encode collab: {}", err);
}
,
// caller may have dropped the receiver
Err(err) => warn!("Failed to receive response from realtime server: {}", err),
}
Ok(())
}
async fn batch_get_collab(
&self,
_uid: &i64,

View file

@ -15,7 +15,6 @@ use std::{
sync::{Arc, Weak},
};
use tracing::error;
pub type CLCommandSender = tokio::sync::mpsc::Sender<CollaborationCommand>;
pub type CLCommandReceiver = tokio::sync::mpsc::Receiver<CollaborationCommand>;

View file

@ -43,10 +43,8 @@ impl ConnectState {
let old_user = e.insert(new_user.clone());
trace!("[realtime]: new connection replaces old => {}", new_user);
if let Some((_, old_stream)) = self.client_message_routers.remove(&old_user) {
info!(
"Removing old stream for same user and device: {}",
old_user.uid
);
info!("Removing old stream for same user and device: {}", old_user);
old_stream
.sink
.do_send(RealtimeMessage::System(SystemMessage::DuplicateConnection));

View file

@ -258,7 +258,7 @@ impl CollabBroadcast {
async fn handle_client_messages<Sink>(
object_id: &str,
message_map: MessageByObjectId,
message_by_object_id: MessageByObjectId,
sink: &mut Sink,
collab: Arc<RwLock<dyn BorrowMut<Collab> + Send + Sync + 'static>>,
metrics_calculate: &Arc<CollabRealtimeMetrics>,
@ -267,7 +267,7 @@ async fn handle_client_messages<Sink>(
Sink: SinkExt<CollabMessage> + Unpin + 'static,
<Sink as futures_util::Sink<CollabMessage>>::Error: std::error::Error,
{
for (message_object_id, collab_messages) in message_map {
for (message_object_id, collab_messages) in message_by_object_id.into_inner() {
// Ignore messages where the object_id does not match. This situation should not occur, as
// [ClientMessageRouter::init_client_communication] is expected to filter out such messages. However,
// as a precautionary measure, we perform this check to handle any unexpected cases.
@ -328,9 +328,6 @@ async fn handle_one_client_message(
// If the payload is empty, we don't need to apply any updates .
// Currently, only the ping message should has an empty payload.
if collab_msg.payload().is_empty() {
if !matches!(collab_msg, ClientCollabMessage::ClientCollabStateCheck(_)) {
error!("receive unexpected empty payload message:{}", collab_msg);
}
return Ok(CollabAck::new(
message_origin,
object_id.to_string(),

View file

@ -1,21 +1,25 @@
use std::collections::HashMap;
use std::sync::Arc;
use async_stream::stream;
use collab::core::origin::CollabOrigin;
use collab::entity::EncodedCollab;
use dashmap::DashMap;
use futures_util::StreamExt;
use tracing::{instrument, trace, warn};
use collab_rt_entity::user::RealtimeUser;
use collab_rt_entity::{AckCode, ClientCollabMessage, ServerCollabMessage, SinkMessage};
use collab_rt_entity::{CollabAck, RealtimeMessage};
use database::collab::CollabStorage;
use crate::client::client_msg_router::ClientMessageRouter;
use crate::error::RealtimeError;
use crate::group::manager::GroupManager;
use crate::group::null_sender::NullSender;
use async_stream::stream;
use bytes::Bytes;
use collab::core::origin::{CollabClient, CollabOrigin};
use collab::entity::EncodedCollab;
use collab_entity::CollabType;
use collab_rt_entity::user::RealtimeUser;
use collab_rt_entity::CollabAck;
use collab_rt_entity::{
AckCode, ClientCollabMessage, MessageByObjectId, ServerCollabMessage, SinkMessage, UpdateSync,
};
use collab_rt_protocol::{Message, SyncMessage};
use dashmap::DashMap;
use database::collab::CollabStorage;
use futures_util::StreamExt;
use std::sync::Arc;
use tracing::{error, instrument, trace, warn};
use yrs::updates::encoder::Encode;
use yrs::StateVector;
/// Using [GroupCommand] to interact with the group
/// - HandleClientCollabMessage: Handle the client message
@ -28,6 +32,14 @@ pub enum GroupCommand {
collab_messages: Vec<ClientCollabMessage>,
ret: tokio::sync::oneshot::Sender<Result<(), RealtimeError>>,
},
HandleClientHttpUpdate {
user: RealtimeUser,
workspace_id: String,
object_id: String,
update: Bytes,
collab_type: CollabType,
ret: tokio::sync::oneshot::Sender<Result<(), RealtimeError>>,
},
EncodeCollab {
object_id: String,
ret: tokio::sync::oneshot::Sender<Option<EncodedCollab>>,
@ -37,6 +49,15 @@ pub enum GroupCommand {
collab_messages: Vec<ClientCollabMessage>,
ret: tokio::sync::oneshot::Sender<Result<(), RealtimeError>>,
},
GenerateCollabEmbedding {
object_id: String,
ret: tokio::sync::oneshot::Sender<Result<(), RealtimeError>>,
},
CalculateMissingUpdate {
object_id: String,
state_vector: StateVector,
ret: tokio::sync::oneshot::Sender<Result<Vec<u8>, RealtimeError>>,
},
}
pub type GroupCommandSender = tokio::sync::mpsc::Sender<GroupCommand>;
@ -104,6 +125,48 @@ where
warn!("Send handle server collab message result fail: {:?}", err);
}
},
GroupCommand::HandleClientHttpUpdate {
user,
workspace_id,
object_id,
update,
collab_type,
ret,
} => {
let result = self
.handle_client_posted_http_update(
&user,
&workspace_id,
&object_id,
collab_type,
update,
)
.await;
if let Err(err) = ret.send(result) {
warn!("Send handle client update message result fail: {:?}", err);
}
},
GroupCommand::GenerateCollabEmbedding { object_id, ret } => {
// TODO(nathan): generate embedding
trace!("Generate embedding for group:{}", object_id);
let _ = ret.send(Ok(()));
},
GroupCommand::CalculateMissingUpdate {
object_id,
state_vector,
ret,
} => {
let group = self.group_manager.get_group(&object_id).await;
match group {
None => {
let _ = ret.send(Err(RealtimeError::GroupNotFound(object_id.clone())));
},
Some(group) => {
let result = group.calculate_missing_update(state_vector).await;
let _ = ret.send(result);
},
}
},
}
})
.await;
@ -120,7 +183,6 @@ where
/// 2.2 For non-'init sync' messages:
/// - If the group exists: The message is sent to the group for synchronization as per [CollabSyncProtocol].
/// - If the group does not exist: The client is prompted to send an 'init sync' message first.
#[instrument(level = "trace", skip_all)]
async fn handle_client_collab_message(
&self,
@ -151,7 +213,9 @@ where
if !is_user_subscribed {
// safety: messages is not empty because we have checked it before
let first_message = messages.first().unwrap();
self.subscribe_group(user, first_message).await?;
self
.subscribe_group_with_message(user, first_message)
.await?;
}
forward_message_to_group(user, object_id, messages, &self.msg_router_by_user).await;
} else {
@ -159,8 +223,10 @@ where
// If there is no existing group for the given object_id and the message is an 'init message',
// then create a new group and add the user as a subscriber to this group.
if first_message.is_client_init_sync() {
self.create_group(user, first_message).await?;
self.subscribe_group(user, first_message).await?;
self.create_group_with_message(user, first_message).await?;
self
.subscribe_group_with_message(user, first_message)
.await?;
forward_message_to_group(user, object_id, messages, &self.msg_router_by_user).await;
} else if let Some(entry) = self.msg_router_by_user.get(user) {
warn!(
@ -185,6 +251,69 @@ where
Ok(())
}
/// This functions will be called when client post update via http requset
#[instrument(level = "trace", skip_all)]
async fn handle_client_posted_http_update(
&self,
user: &RealtimeUser,
workspace_id: &str,
object_id: &str,
collab_type: collab_entity::CollabType,
update: Bytes,
) -> Result<(), RealtimeError> {
let origin = CollabOrigin::Client(CollabClient {
uid: user.uid,
device_id: user.device_id.clone(),
});
// Create message router for user if it's not exist
let should_sub = self.msg_router_by_user.get(user).is_none();
if should_sub {
trace!("create a new client message router for user:{}", user);
let new_client_router = ClientMessageRouter::new(NullSender::<()>::default());
self
.msg_router_by_user
.insert(user.clone(), new_client_router);
}
// Create group if it's not exist
let is_group_exist = self.group_manager.contains_group(object_id);
if !is_group_exist {
trace!("The group:{} is not found, create a new group", object_id);
self
.create_group(user, workspace_id, object_id, collab_type)
.await?;
}
// Only subscribe when the user is not subscribed to the group
if should_sub {
self.subscribe_group(user, object_id, &origin).await?;
}
if let Some(client_stream) = self.msg_router_by_user.get(user) {
let payload = Message::Sync(SyncMessage::Update(update.to_vec())).encode_v1();
let msg = ClientCollabMessage::ClientUpdateSync {
data: UpdateSync {
origin,
object_id: object_id.to_string(),
msg_id: chrono::Utc::now().timestamp_millis() as u64,
payload: payload.into(),
},
};
let message = MessageByObjectId::new_with_message(object_id.to_string(), vec![msg]);
let err = client_stream.stream_tx.send(message);
if let Err(err) = err {
warn!("Send user:{} http update message to group:{}", user, err);
self.msg_router_by_user.remove(user);
}
} else {
warn!(
"The client stream: {} is not found when applying client update",
user
);
}
Ok(())
}
/// similar to `handle_client_collab_message`, but the messages are sent from the server instead.
#[instrument(level = "trace", skip_all)]
async fn handle_server_collab_messages(
@ -214,12 +343,11 @@ where
collab_message_sender,
message_by_oid_receiver,
);
let message = HashMap::from([(object_id.clone(), messages)]);
let message = MessageByObjectId::new_with_message(object_id.clone(), messages);
if let Err(err) = message_by_oid_sender.try_send(message) {
tracing::error!(
error!(
"failed to send message to group: {}, object_id: {}",
err,
object_id
err, object_id
);
}
};
@ -227,13 +355,22 @@ where
Ok(())
}
async fn subscribe_group(
async fn subscribe_group_with_message(
&self,
user: &RealtimeUser,
collab_message: &ClientCollabMessage,
) -> Result<(), RealtimeError> {
let object_id = collab_message.object_id();
let message_origin = collab_message.origin();
self.subscribe_group(user, object_id, message_origin).await
}
async fn subscribe_group(
&self,
user: &RealtimeUser,
object_id: &str,
collab_origin: &CollabOrigin,
) -> Result<(), RealtimeError> {
match self.msg_router_by_user.get_mut(user) {
None => {
warn!("The client stream: {} is not found", user);
@ -245,7 +382,7 @@ where
.subscribe_group(
user,
object_id,
message_origin,
collab_origin,
client_msg_router.value_mut(),
)
.await
@ -254,7 +391,7 @@ where
}
#[instrument(level = "debug", skip_all)]
async fn create_group(
async fn create_group_with_message(
&self,
user: &RealtimeUser,
collab_message: &ClientCollabMessage,
@ -263,7 +400,6 @@ where
match collab_message {
ClientCollabMessage::ClientInitSync { data, .. } => {
self
.group_manager
.create_group(
user,
&data.workspace_id,
@ -271,12 +407,27 @@ where
data.collab_type.clone(),
)
.await?;
Ok(())
},
_ => Err(RealtimeError::ExpectInitSync(collab_message.to_string())),
}
}
#[instrument(level = "debug", skip_all)]
async fn create_group(
&self,
user: &RealtimeUser,
workspace_id: &str,
object_id: &str,
collab_type: collab_entity::CollabType,
) -> Result<(), RealtimeError> {
self
.group_manager
.create_group(user, workspace_id, object_id, collab_type)
.await?;
Ok(())
}
}
/// Forward the message to the group.
@ -299,10 +450,8 @@ pub async fn forward_message_to_group(
.map(|v| v.msg_id())
.collect::<Vec<_>>()
);
let pair = (object_id, collab_messages);
let err = client_stream
.stream_tx
.send(RealtimeMessage::ClientCollabV2([pair].into()));
let message = MessageByObjectId::new_with_message(object_id, collab_messages);
let err = client_stream.stream_tx.send(message);
if let Err(err) = err {
warn!("Send user:{} message to group:{}", user.uid, err);
client_msg_router.remove(user);

View file

@ -8,14 +8,14 @@ use collab::entity::EncodedCollab;
use collab::lock::RwLock;
use collab::preclude::Collab;
use collab_entity::CollabType;
use collab_rt_entity::user::RealtimeUser;
use collab_rt_entity::CollabMessage;
use collab_rt_entity::MessageByObjectId;
use dashmap::DashMap;
use futures_util::{SinkExt, StreamExt};
use tokio_util::sync::CancellationToken;
use tracing::{event, info, trace};
use collab_rt_entity::user::RealtimeUser;
use collab_rt_entity::CollabMessage;
use collab_rt_entity::MessageByObjectId;
use yrs::{ReadTxn, StateVector};
use collab_stream::error::StreamError;
@ -78,7 +78,7 @@ impl CollabGroup {
uid,
storage,
edit_state.clone(),
Arc::downgrade(&collab),
collab.clone(),
collab_type.clone(),
persistence_interval,
indexer,
@ -99,6 +99,19 @@ impl CollabGroup {
})
}
pub async fn calculate_missing_update(
&self,
state_vector: StateVector,
) -> Result<Vec<u8>, RealtimeError> {
let guard = self.collab.read().await;
let txn = guard.transact();
let update = txn.encode_state_as_update_v1(&state_vector);
drop(txn);
drop(guard);
Ok(update)
}
pub async fn encode_collab(&self) -> Result<EncodedCollab, RealtimeError> {
let lock = self.collab.read().await;
let encode_collab = lock.encode_collab_v1(|collab| {

View file

@ -61,7 +61,7 @@ where
}
pub fn get_inactive_groups(&self) -> Vec<String> {
self.state.get_inactive_group_ids()
self.state.remove_inactive_groups()
}
pub fn contains_user(&self, object_id: &str, user: &RealtimeUser) -> bool {
@ -80,11 +80,6 @@ where
self.state.get_group(object_id).await
}
#[instrument(skip(self))]
fn remove_group(&self, object_id: &str) {
self.state.remove_group(object_id);
}
pub async fn subscribe_group(
&self,
user: &RealtimeUser,
@ -154,7 +149,7 @@ where
collab_type
);
let mut indexer = self.indexer_provider.indexer_for(collab_type.clone());
let mut indexer = self.indexer_provider.indexer_for(&collab_type);
if indexer.is_some()
&& !self
.indexer_provider

View file

@ -2,6 +2,7 @@ pub(crate) mod broadcast;
pub(crate) mod cmd;
pub(crate) mod group_init;
pub(crate) mod manager;
mod null_sender;
mod persistence;
mod plugin;
pub(crate) mod protocol;

View file

@ -0,0 +1,53 @@
use crate::error::RealtimeError;
use crate::RealtimeClientWebsocketSink;
use collab_rt_entity::RealtimeMessage;
use futures::Sink;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Futures [Sink] compatible sender, that always throws the input away.
/// Essentially: a `/dev/null` equivalent.
#[derive(Clone)]
pub(crate) struct NullSender<T> {
_marker: PhantomData<T>,
}
impl<T> Default for NullSender<T> {
fn default() -> Self {
NullSender {
_marker: PhantomData,
}
}
}
impl<T> Sink<T> for NullSender<T> {
type Error = RealtimeError;
#[inline]
fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn start_send(self: Pin<&mut Self>, _: T) -> Result<(), Self::Error> {
Ok(())
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
}
impl<T> RealtimeClientWebsocketSink for NullSender<T>
where
T: Send + Sync + 'static,
{
fn do_send(&self, _message: RealtimeMessage) {}
}

View file

@ -1,4 +1,4 @@
use std::sync::{Arc, Weak};
use std::sync::Arc;
use std::time::Duration;
use anyhow::anyhow;
@ -22,7 +22,9 @@ pub(crate) struct GroupPersistence<S> {
storage: Arc<S>,
uid: i64,
edit_state: Arc<EditState>,
collab: Weak<RwLock<Collab>>,
/// Use Arc<RwLock<Collab>> instead of Weak<RwLock<Collab>> to make sure the collab is not dropped
/// when saving collab data to disk
collab: Arc<RwLock<Collab>>,
collab_type: CollabType,
persistence_interval: Duration,
indexer: Option<Arc<dyn Indexer>>,
@ -40,7 +42,7 @@ where
uid: i64,
storage: Arc<S>,
edit_state: Arc<EditState>,
collab: Weak<RwLock<Collab>>,
collab: Arc<RwLock<Collab>>,
collab_type: CollabType,
persistence_interval: Duration,
ai_client: Option<Arc<dyn Indexer>>,
@ -65,7 +67,6 @@ where
loop {
// delay 30 seconds before the first save. We don't want to save immediately after the collab is created
tokio::time::sleep(Duration::from_secs(30)).await;
tokio::select! {
_ = interval.tick() => {
if self.attempt_save().await.is_err() {
@ -117,17 +118,13 @@ where
Ok(())
}
async fn save(&self, write_immediately: bool) -> Result<(), AppError> {
async fn save(&self, flush_to_disk: bool) -> Result<(), AppError> {
let object_id = self.object_id.clone();
let workspace_id = self.workspace_id.clone();
let collab_type = self.collab_type.clone();
let collab = match self.collab.upgrade() {
Some(collab) => collab,
None => return Err(AppError::Internal(anyhow!("collab has been dropped"))),
};
let params = {
let cloned_collab = collab.clone();
let cloned_collab = self.collab.clone();
let (workspace_id, mut params, object_id) = tokio::task::spawn_blocking(move || {
let collab = cloned_collab.blocking_read();
let params = get_encode_collab(&workspace_id, &object_id, &collab, &collab_type)?;
@ -135,7 +132,7 @@ where
})
.await??;
let lock = collab.read().await;
let lock = self.collab.read().await;
if let Some(indexer) = &self.indexer {
match indexer.embedding_params(&lock).await {
Ok(embedding_params) => {
@ -165,7 +162,7 @@ where
self
.storage
.queue_insert_or_update_collab(&self.workspace_id, &self.uid, params, write_immediately)
.queue_insert_or_update_collab(&self.workspace_id, &self.uid, params, flush_to_disk)
.await?;
// Update the edit state on successful save
self.edit_state.tick();

View file

@ -36,9 +36,8 @@ impl GroupManagementState {
}
}
/// Performs a periodic check to remove groups based on the following conditions:
/// Groups that have been inactive for a specified period of time.
pub fn get_inactive_group_ids(&self) -> Vec<String> {
/// Returns group ids of inactive groups.
pub fn remove_inactive_groups(&self) -> Vec<String> {
let mut inactive_group_ids = vec![];
for entry in self.group_by_object_id.iter() {
let (object_id, group) = (entry.key(), entry.value());

View file

@ -1,45 +1,38 @@
use std::sync::Arc;
use anyhow::anyhow;
use async_trait::async_trait;
use collab::preclude::Collab;
use crate::indexer::{DocumentDataExt, Indexer};
use app_error::AppError;
use appflowy_ai_client::client::AppFlowyAIClient;
use appflowy_ai_client::dto::{
EmbeddingEncodingFormat, EmbeddingInput, EmbeddingModel, EmbeddingOutput, EmbeddingRequest,
};
use async_trait::async_trait;
use collab::preclude::Collab;
use collab_document::document::DocumentBody;
use collab_document::error::DocumentError;
use collab_entity::CollabType;
use database_entity::dto::{AFCollabEmbeddingParams, AFCollabEmbeddings, EmbeddingContentType};
use std::sync::Arc;
use crate::config::get_env_var;
use crate::indexer::open_ai::{split_text_by_max_content_len, split_text_by_max_tokens};
use crate::indexer::open_ai::split_text_by_max_content_len;
use crate::indexer::Indexer;
use tiktoken_rs::CoreBPE;
use tracing::trace;
use uuid::Uuid;
pub struct DocumentIndexer {
ai_client: AppFlowyAIClient,
#[allow(dead_code)]
tokenizer: Arc<CoreBPE>,
embedding_model: EmbeddingModel,
use_tiktoken: bool,
}
impl DocumentIndexer {
pub fn new(ai_client: AppFlowyAIClient) -> Arc<Self> {
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
let use_tiktoken = get_env_var("APPFLOWY_AI_CONTENT_SPLITTER_TIKTOKEN", "false")
.parse::<bool>()
.unwrap_or(false);
Arc::new(Self {
ai_client,
tokenizer: Arc::new(tokenizer),
embedding_model: EmbeddingModel::TextEmbedding3Small,
use_tiktoken,
})
}
}
@ -58,17 +51,14 @@ impl Indexer for DocumentIndexer {
)
})?;
let result = document.get_document_data(&collab.transact());
let result = document.to_plain_text(collab.transact(), false);
match result {
Ok(document_data) => {
let content = document_data.to_plain_text();
Ok(content) => {
create_embedding(
object_id,
content,
CollabType::Document,
&self.embedding_model,
self.tokenizer.clone(),
self.use_tiktoken,
)
.await
},
@ -82,6 +72,15 @@ impl Indexer for DocumentIndexer {
}
}
async fn embedding_text(
&self,
object_id: String,
content: String,
collab_type: CollabType,
) -> Result<Vec<AFCollabEmbeddingParams>, AppError> {
create_embedding(object_id, content, collab_type, &self.embedding_model).await
}
async fn embeddings(
&self,
mut params: Vec<AFCollabEmbeddingParams>,
@ -142,29 +141,14 @@ async fn create_embedding(
content: String,
collab_type: CollabType,
embedding_model: &EmbeddingModel,
tokenizer: Arc<CoreBPE>,
use_tiktoken: bool,
) -> Result<Vec<AFCollabEmbeddingParams>, AppError> {
let split_contents = if use_tiktoken {
let max_tokens = embedding_model.default_dimensions() as usize;
if content.len() < 500 {
split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())?
} else {
tokio::task::spawn_blocking(move || {
split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())
})
.await??
}
} else {
debug_assert!(matches!(
embedding_model,
EmbeddingModel::TextEmbedding3Small
));
// We assume that every token is ~4 bytes. We're going to split document content into fragments
// of ~2000 tokens each.
split_text_by_max_content_len(content, 8000)?
};
debug_assert!(matches!(
embedding_model,
EmbeddingModel::TextEmbedding3Small
));
// We assume that every token is ~4 bytes. We're going to split document content into fragments
// of ~2000 tokens each.
let split_contents = split_text_by_max_content_len(content, 8000)?;
Ok(
split_contents
.into_iter()

View file

@ -1,72 +0,0 @@
use collab_document::blocks::{DocumentData, TextDelta};
pub trait DocumentDataExt {
fn to_plain_text(&self) -> String;
}
impl DocumentDataExt for DocumentData {
fn to_plain_text(&self) -> String {
let mut buf = String::new();
if let Some(text_map) = self.meta.text_map.as_ref() {
let mut stack = Vec::new();
stack.push(&self.page_id);
// do a depth-first scan of the document blocks
while let Some(block_id) = stack.pop() {
if let Some(block) = self.blocks.get(block_id) {
if let Some(deltas) = get_delta_from_block_data(block) {
push_deltas_to_str(&mut buf, deltas);
} else if let Some(deltas) = get_delta_from_external_text_id(block, text_map) {
push_deltas_to_str(&mut buf, deltas);
}
if let Some(children) = self.meta.children_map.get(&block.children) {
// we want to process children blocks in the same order they are given in children_map
// however stack.pop gives us the last element first, so we push children
// in reverse order
stack.extend(children.iter().rev());
}
}
}
}
//tracing::trace!("Document plain text: `{}`", buf);
buf
}
}
/// Try to retrieve deltas from `block.data.delta`.
fn get_delta_from_block_data(block: &collab_document::blocks::Block) -> Option<Vec<TextDelta>> {
if let Some(delta) = block.data.get("delta") {
if let Ok(deltas) = serde_json::from_value::<Vec<TextDelta>>(delta.clone()) {
return Some(deltas);
}
}
None
}
/// Try to retrieve deltas from text_map's text associated with `block.external_id`.
fn get_delta_from_external_text_id(
block: &collab_document::blocks::Block,
text_map: &std::collections::HashMap<String, String>,
) -> Option<Vec<TextDelta>> {
if block.external_type.as_deref() == Some("text") {
if let Some(text_id) = block.external_id.as_deref() {
if let Some(json) = text_map.get(text_id) {
if let Ok(deltas) = serde_json::from_str::<Vec<TextDelta>>(json) {
return Some(deltas);
}
}
}
}
None
}
fn push_deltas_to_str(buf: &mut String, deltas: Vec<TextDelta>) {
for delta in deltas {
if let TextDelta::Inserted(text, _) = delta {
let trimmed = text.trim();
if !trimmed.is_empty() {
buf.push_str(trimmed);
buf.push(' ');
}
}
}
}

View file

@ -1,8 +1,6 @@
mod document_indexer;
mod ext;
mod open_ai;
mod provider;
pub use document_indexer::DocumentIndexer;
pub use ext::DocumentDataExt;
pub use provider::*;

View file

@ -30,6 +30,7 @@ use unicode_segmentation::UnicodeSegmentation;
/// https://tokio.rs/blog/2020-04-preemption
/// https://ryhl.io/blog/async-what-is-blocking/
#[inline]
#[allow(dead_code)]
pub fn split_text_by_max_tokens(
content: String,
max_tokens: usize,

View file

@ -31,6 +31,13 @@ pub trait Indexer: Send + Sync {
collab: &Collab,
) -> Result<Vec<AFCollabEmbeddingParams>, AppError>;
async fn embedding_text(
&self,
object_id: String,
content: String,
collab_type: CollabType,
) -> Result<Vec<AFCollabEmbeddingParams>, AppError>;
async fn embeddings(
&self,
params: Vec<AFCollabEmbeddingParams>,
@ -90,8 +97,8 @@ impl IndexerProvider {
/// Returns indexer for a specific type of [Collab] object.
/// If collab of given type is not supported or workspace it belongs to has indexing disabled,
/// returns `None`.
pub fn indexer_for(&self, collab_type: CollabType) -> Option<Arc<dyn Indexer>> {
self.indexer_cache.get(&collab_type).cloned()
pub fn indexer_for(&self, collab_type: &CollabType) -> Option<Arc<dyn Indexer>> {
self.indexer_cache.get(collab_type).cloned()
}
fn get_unindexed_collabs(
@ -188,7 +195,7 @@ impl IndexerProvider {
let collab_type = params.collab_type.clone();
let data = params.encoded_collab_v1.clone();
if let Some(indexer) = self.indexer_for(collab_type) {
if let Some(indexer) = self.indexer_for(&collab_type) {
let encoded_collab = tokio::task::spawn_blocking(move || {
let encode_collab = EncodedCollab::decode_from_bytes(&data)?;
Ok::<_, AppError>(encode_collab)

View file

@ -1,15 +1,19 @@
use std::sync::{Arc, Weak};
use std::time::Duration;
use anyhow::Result;
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use tokio::time::interval;
use tracing::{error, info, trace};
use access_control::collab::RealtimeAccessControl;
use anyhow::{anyhow, Result};
use app_error::AppError;
use collab_rt_entity::user::{RealtimeUser, UserDevice};
use collab_rt_entity::MessageByObjectId;
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use tokio::sync::mpsc::Sender;
use tokio::task::yield_now;
use tokio::time::interval;
use tracing::{error, info, trace, warn};
use yrs::updates::decoder::Decode;
use yrs::StateVector;
use database::collab::CollabStorage;
@ -23,6 +27,7 @@ use crate::group::manager::GroupManager;
use crate::indexer::IndexerProvider;
use crate::rt_server::collaboration_runtime::COLLAB_RUNTIME;
use crate::actix_ws::entities::{ClientGenerateEmbeddingMessage, ClientHttpUpdateMessage};
use crate::{CollabRealtimeMetrics, RealtimeClientWebsocketSink};
#[derive(Clone)]
@ -151,41 +156,8 @@ where
user: RealtimeUser,
message_by_oid: MessageByObjectId,
) -> Result<(), RealtimeError> {
let group_sender_by_object_id = self.group_sender_by_object_id.clone();
let client_msg_router_by_user = self.connect_state.client_message_routers.clone();
let group_manager = self.group_manager.clone();
let enable_custom_runtime = self.enable_custom_runtime;
for (object_id, collab_messages) in message_by_oid {
let old_sender = group_sender_by_object_id
.get(&object_id)
.map(|entry| entry.value().clone());
let sender = match old_sender {
Some(sender) => sender,
None => match group_sender_by_object_id.entry(object_id.clone()) {
Entry::Occupied(entry) => entry.get().clone(),
Entry::Vacant(entry) => {
let (new_sender, recv) = tokio::sync::mpsc::channel(2000);
let runner = GroupCommandRunner {
group_manager: group_manager.clone(),
msg_router_by_user: client_msg_router_by_user.clone(),
recv: Some(recv),
};
let object_id = entry.key().clone();
if enable_custom_runtime {
COLLAB_RUNTIME.spawn(runner.run(object_id));
} else {
tokio::spawn(runner.run(object_id));
}
entry.insert(new_sender.clone());
new_sender
},
},
};
for (object_id, collab_messages) in message_by_oid.into_inner() {
let group_cmd_sender = self.create_group_if_not_exist(&object_id);
let cloned_user = user.clone();
// Create a new task to send a message to the group command runner without waiting for the
// result. This approach is used to prevent potential issues with the actor's mailbox in
@ -193,7 +165,7 @@ where
// immediately proceed to process the next message.
tokio::spawn(async move {
let (tx, rx) = tokio::sync::oneshot::channel();
match sender
match group_cmd_sender
.send(GroupCommand::HandleClientCollabMessage {
user: cloned_user,
object_id,
@ -226,6 +198,204 @@ where
Ok(())
}
/// Applies a collab update received over plain HTTP (instead of the
/// websocket realtime channel) by forwarding it to the group command
/// runner for `message.object_id`, creating that runner on demand.
///
/// All work happens in a spawned task, so this method returns immediately.
/// The caller observes the outcome through `message.return_tx` (if set):
/// * `Ok(None)` — update applied, no state vector was supplied.
/// * `Ok(Some(update))` — update applied and, because `message.state_vector`
///   decoded successfully, the missing update relative to that state vector.
/// * `Err(_)` — enqueueing/applying the update or computing the missing
///   update failed.
/// When `return_tx` is `None`, failures are only logged.
#[inline]
pub fn handle_client_http_update(
&self,
message: ClientHttpUpdateMessage,
) -> Result<(), RealtimeError> {
// Make sure a command runner exists for this object before handing off work.
let group_cmd_sender = self.create_group_if_not_exist(&message.object_id);
tokio::spawn(async move {
let object_id = message.object_id.clone();
let (tx, rx) = tokio::sync::oneshot::channel();
let result = group_cmd_sender
.send(GroupCommand::HandleClientHttpUpdate {
user: message.user,
workspace_id: message.workspace_id,
object_id: message.object_id,
update: message.update,
collab_type: message.collab_type,
ret: tx,
})
.await;
let return_tx = message.return_tx;
// Could not enqueue the command: reply through return_tx when present
// (and stop), otherwise just log. Without a reply channel we fall
// through and rx.await below fails too; that error is logged there.
if let Err(err) = result {
if let Some(return_rx) = return_tx {
let _ = return_rx.send(Err(AppError::Internal(anyhow!(
"send update to group fail: {}",
err
))));
return;
} else {
error!("send http update to group fail: {}", err);
}
}
match rx.await {
// The group applied the update successfully.
Ok(Ok(())) => {
// A state vector only makes sense when the caller can receive the
// missing update; flag the inconsistent combination.
if message.state_vector.is_some() && return_tx.is_none() {
warn!(
"state_vector is not None, but return_tx is None, object_id: {}",
object_id
);
}
if let Some(return_rx) = return_tx {
// Silently ignores a state vector that fails to decode; the
// caller then receives Ok(None) below.
if let Some(state_vector) = message
.state_vector
.and_then(|data| StateVector::decode_v1(&data).ok())
{
// Yield once so other queued tasks can make progress before the
// (potentially heavy) missing-update calculation.
yield_now().await;
// Ask the group to compute the update the client is missing
// relative to its state vector.
let (tx, rx) = tokio::sync::oneshot::channel();
let _ = group_cmd_sender
.send(GroupCommand::CalculateMissingUpdate {
object_id,
state_vector,
ret: tx,
})
.await;
match rx.await {
Ok(missing_update_result) => {
let result = missing_update_result
.map_err(|err| {
AppError::Internal(anyhow!("fail to calculate missing update: {}", err))
})
.map(Some);
let _ = return_rx.send(result);
},
Err(err) => {
let _ = return_rx.send(Err(AppError::Internal(anyhow!(
"fail to calculate missing update: {}",
err
))));
},
}
} else {
// Applied, but no (decodable) state vector: nothing more to return.
let _ = return_rx.send(Ok(None));
}
}
},
// The group rejected or failed to apply the update.
Ok(Err(err)) => {
if let Some(return_rx) = return_tx {
let _ = return_rx.send(Err(AppError::Internal(anyhow!(
"apply http update to group fail: {}",
err
))));
} else {
error!("apply http update to group fail: {}", err);
}
},
// The group dropped the reply channel without answering.
Err(err) => {
if let Some(return_rx) = return_tx {
let _ = return_rx.send(Err(AppError::Internal(anyhow!(
"fail to receive applied result: {}",
err
))));
} else {
error!("fail to receive applied result: {}", err);
}
},
}
});
Ok(())
}
/// Returns the command channel for the group owning `object_id`, lazily
/// spinning up a `GroupCommandRunner` for it on first use.
///
/// The fast path is a plain lookup; only on a miss do we take the dashmap
/// entry, which re-checks under the shard lock so concurrent callers never
/// spawn two runners for the same object.
#[inline]
fn create_group_if_not_exist(&self, object_id: &str) -> Sender<GroupCommand> {
  // Fast path: a runner already exists for this object.
  if let Some(existing) = self.group_sender_by_object_id.get(object_id) {
    return existing.value().clone();
  }
  match self.group_sender_by_object_id.entry(object_id.to_string()) {
    // Another caller raced us and created the runner first.
    Entry::Occupied(occupied) => occupied.get().clone(),
    Entry::Vacant(vacant) => {
      let (sender, receiver) = tokio::sync::mpsc::channel(2000);
      let runner = GroupCommandRunner {
        group_manager: self.group_manager.clone(),
        msg_router_by_user: self.connect_state.client_message_routers.clone(),
        recv: Some(receiver),
      };
      let oid = vacant.key().clone();
      // Run the group loop on the dedicated collab runtime when enabled,
      // otherwise on the default tokio runtime.
      if self.enable_custom_runtime {
        COLLAB_RUNTIME.spawn(runner.run(oid));
      } else {
        tokio::spawn(runner.run(oid));
      }
      vacant.insert(sender.clone());
      sender
    },
  }
}
/// Asks the group that owns `message.object_id` to (re)generate the
/// embedding for its collab, creating the group's command runner on demand.
///
/// The request is processed in a spawned task; this method returns
/// immediately. The outcome is delivered through `message.return_tx` when
/// present (`Ok(())` on success, `Err(_)` otherwise); without a reply
/// channel, failures are only logged.
#[inline]
pub fn handle_client_generate_embedding_request(
&self,
message: ClientGenerateEmbeddingMessage,
) -> Result<(), RealtimeError> {
// Make sure a command runner exists for this object before handing off work.
let group_cmd_sender = self.create_group_if_not_exist(&message.object_id);
tokio::spawn(async move {
let (tx, rx) = tokio::sync::oneshot::channel();
let result = group_cmd_sender
.send(GroupCommand::GenerateCollabEmbedding {
object_id: message.object_id,
ret: tx,
})
.await;
// Could not enqueue the command: reply when a channel exists (and stop),
// otherwise log and fall through — rx.await below then fails too and is
// logged in the last match arm.
if let Err(err) = result {
if let Some(return_tx) = message.return_tx {
let _ = return_tx.send(Err(AppError::Internal(anyhow!(
"send generate embedding to group fail: {}",
err
))));
return;
} else {
error!("send generate embedding to group fail: {}", err);
}
}
match rx.await {
// Embedding generated successfully.
Ok(Ok(())) => {
if let Some(return_tx) = message.return_tx {
let _ = return_tx.send(Ok(()));
}
},
// The group reported a failure while generating the embedding.
Ok(Err(err)) => {
if let Some(return_tx) = message.return_tx {
let _ = return_tx.send(Err(AppError::Internal(anyhow!(
"generate embedding fail: {}",
err
))));
} else {
error!("generate embedding fail: {}", err);
}
},
// The group dropped the reply channel without answering.
Err(err) => {
if let Some(return_tx) = message.return_tx {
let _ = return_tx.send(Err(AppError::Internal(anyhow!(
"fail to receive generate embedding result: {}",
err
))));
} else {
error!("fail to receive generate embedding result: {}", err);
}
},
}
});
Ok(())
}
pub fn get_user_by_device(&self, user_device: &UserDevice) -> Option<RealtimeUser> {
self
.connect_state

View file

@ -1,4 +1,6 @@
use appflowy_collaborate::indexer::DocumentDataExt;
use collab::core::origin::CollabOrigin;
use collab::preclude::Collab;
use collab_document::document::Document;
use workspace_template::document::getting_started::{
get_initial_document_data, getting_started_document_data,
};
@ -6,7 +8,9 @@ use workspace_template::document::getting_started::{
#[test]
fn document_plain_text() {
let doc = getting_started_document_data().unwrap();
let text = doc.to_plain_text();
let collab = Collab::new_with_origin(CollabOrigin::Server, "1", vec![], false);
let document = Document::create_with_data(collab, doc).unwrap();
let text = document.to_plain_text(false).unwrap();
let expected = "Welcome to AppFlowy $ Download for macOS, Windows, and Linux link $ $ quick start Ask AI powered by advanced AI models: chat, search, write, and much more ✨ ❤\u{fe0f}Love AppFlowy and open source? Follow our latest product updates: Twitter : @appflowy Reddit : r/appflowy Github ";
assert_eq!(&text, expected);
}
@ -14,7 +18,9 @@ fn document_plain_text() {
#[test]
fn document_plain_text_with_nested_blocks() {
let doc = get_initial_document_data().unwrap();
let text = doc.to_plain_text();
let collab = Collab::new_with_origin(CollabOrigin::Server, "1", vec![], false);
let document = Document::create_with_data(collab, doc).unwrap();
let text = document.to_plain_text(false).unwrap();
let expected = "Welcome to AppFlowy! Here are the basics Here is H3 Click anywhere and just start typing. Click Enter to create a new line. Highlight any text, and use the editing menu to style your writing however you like. As soon as you type / a menu will pop up. Select different types of content blocks you can add. Type / followed by /bullet or /num to create a list. Click + New Page button at the bottom of your sidebar to add a new page. Click + next to any page title in the sidebar to quickly add a new subpage, Document , Grid , or Kanban Board . Keyboard shortcuts, markdown, and code block Keyboard shortcuts guide Markdown reference Type /code to insert a code block // This is the main function.\nfn main() {\n // Print text to the console.\n println!(\"Hello World!\");\n} This is a paragraph This is a paragraph Have a question❓ Click ? at the bottom right for help and support. This is a paragraph This is a paragraph Click ? at the bottom right for help and support. Like AppFlowy? Follow us: GitHub Twitter : @appflowy Newsletter ";
assert_eq!(&text, expected);
}

View file

@ -145,6 +145,7 @@ pub struct AppState {
pub redis_client: ConnectionManager,
pub pg_pool: PgPool,
pub s3_client: S3ClientImpl,
#[allow(dead_code)]
pub mailer: AFWorkerMailer,
pub metrics: AppMetrics,
}

View file

@ -52,18 +52,38 @@ pub fn compress_type_from_header_value(headers: &HeaderMap) -> Result<Compressio
}
}
pub fn device_id_from_headers(headers: &HeaderMap) -> Result<String, AppError> {
headers
.get("device_id")
.ok_or(AppError::InvalidRequest(
"Missing device_id header".to_string(),
))
fn value_from_headers<'a>(
headers: &'a HeaderMap,
keys: &[&str],
missing_msg: &str,
) -> Result<&'a str, AppError> {
keys
.iter()
.find_map(|key| headers.get(*key))
.ok_or_else(|| AppError::InvalidRequest(missing_msg.to_string()))
.and_then(|header| {
header
.to_str()
.map_err(|err| AppError::InvalidRequest(format!("Failed to parse device_id: {}", err)))
.map_err(|err| AppError::InvalidRequest(format!("Failed to parse header: {}", err)))
})
.map(|s| s.to_string())
}
/// Retrieve the client version from request headers.
///
/// Accepts any of the spellings `Client-Version`, `client-version` or
/// `client_version` (first match wins) and returns the raw header value.
/// Fails with `AppError::InvalidRequest` when the header is absent or its
/// value is not valid UTF-8.
pub fn client_version_from_headers(headers: &HeaderMap) -> Result<&str, AppError> {
value_from_headers(
headers,
&["Client-Version", "client-version", "client_version"],
"Missing Client-Version or client-version header",
)
}
/// Retrieve the device ID from request headers.
///
/// Accepts any of the spellings `Device-Id`, `device-id`, `device_id` or
/// `Device-ID` (first match wins) and returns the raw header value. Fails
/// with `AppError::InvalidRequest` when the header is absent or its value
/// is not valid UTF-8.
pub fn device_id_from_headers(headers: &HeaderMap) -> Result<&str, AppError> {
value_from_headers(
headers,
&["Device-Id", "device-id", "device_id", "Device-ID"],
"Missing Device-Id or device_id header",
)
}
#[async_trait]
@ -181,3 +201,114 @@ pub(crate) fn ai_model_from_header(req: &HttpRequest) -> AIModel {
})
.unwrap_or(AIModel::GPT4oMini)
}
// Unit tests for the header-extraction helpers above
// (`value_from_headers`, `client_version_from_headers`,
// `device_id_from_headers`).
#[cfg(test)]
mod tests {
use super::*;
use actix_http::header::{HeaderMap, HeaderName, HeaderValue};
/// Builds a `HeaderMap` containing a single `key: value` entry.
fn setup_headers(key: &str, value: &str) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_str(key).unwrap(),
HeaderValue::from_str(value).unwrap(),
);
headers
}
// Every accepted spelling of the client-version header resolves to its value.
#[test]
fn test_client_version_valid_variations() {
let test_cases = [
("Client-Version", "1.0.0"),
("client-version", "2.0.0"),
("client_version", "3.0.0"),
];
for (key, value) in test_cases.iter() {
let headers = setup_headers(key, value);
let result = client_version_from_headers(&headers);
assert!(result.is_ok());
assert_eq!(result.unwrap(), *value);
}
}
// Every accepted spelling of the device-id header resolves to its value.
#[test]
fn test_device_id_valid_variations() {
let test_cases = [
("Device-Id", "device123"),
("device-id", "device456"),
("device_id", "device789"),
("Device-ID", "device000"),
];
for (key, value) in test_cases.iter() {
let headers = setup_headers(key, value);
let result = device_id_from_headers(&headers);
assert!(result.is_ok());
assert_eq!(result.unwrap(), *value);
}
}
// Absent header yields InvalidRequest with the documented message.
#[test]
fn test_missing_client_version() {
let headers = HeaderMap::new();
let result = client_version_from_headers(&headers);
assert!(result.is_err());
match result {
Err(AppError::InvalidRequest(msg)) => {
assert_eq!(msg, "Missing Client-Version or client-version header");
},
_ => panic!("Expected InvalidRequest error"),
}
}
// Absent header yields InvalidRequest with the documented message.
#[test]
fn test_missing_device_id() {
let headers = HeaderMap::new();
let result = device_id_from_headers(&headers);
assert!(result.is_err());
match result {
Err(AppError::InvalidRequest(msg)) => {
assert_eq!(msg, "Missing Device-Id or device_id header");
},
_ => panic!("Expected InvalidRequest error"),
}
}
// A present header whose bytes are not UTF-8 fails the to_str() step.
// NOTE(review): relies on HeaderValue::from_bytes accepting opaque 0xFF
// bytes (valid obs-text per RFC 7230) while to_str() rejects them.
#[test]
fn test_invalid_header_value() {
let mut headers = HeaderMap::new();
// Create an invalid UTF-8 header value
headers.insert(
HeaderName::from_str("Client-Version").unwrap(),
HeaderValue::from_bytes(&[0xFF, 0xFF]).unwrap(),
);
let result = client_version_from_headers(&headers);
assert!(result.is_err());
match result {
Err(AppError::InvalidRequest(msg)) => {
assert!(msg.starts_with("Failed to parse header:"));
},
_ => panic!("Expected InvalidRequest error"),
}
}
// When several candidate keys are present, the first key in the slice wins.
#[test]
fn test_value_from_headers_multiple_keys_present() {
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_str("key1").unwrap(),
HeaderValue::from_static("value1"),
);
headers.insert(
HeaderName::from_str("key2").unwrap(),
HeaderValue::from_static("value2"),
);
let result = value_from_headers(&headers, &["key1", "key2"], "Missing key");
assert!(result.is_ok());
// Should return the first matching key's value
assert_eq!(result.unwrap(), "value1");
}
}

View file

@ -1,44 +1,4 @@
use access_control::act::Action;
use actix_web::web::{Bytes, Path, Payload};
use actix_web::web::{Data, Json, PayloadConfig};
use actix_web::{web, Scope};
use actix_web::{HttpRequest, Result};
use anyhow::{anyhow, Context};
use bytes::BytesMut;
use chrono::{DateTime, Duration, Utc};
use collab::entity::EncodedCollab;
use collab_database::entity::FieldType;
use collab_entity::CollabType;
use futures_util::future::try_join_all;
use prost::Message as ProstMessage;
use rayon::prelude::*;
use sqlx::types::uuid;
use std::collections::HashMap;
use std::time::Instant;
use tokio_stream::StreamExt;
use tokio_tungstenite::tungstenite::Message;
use tracing::{error, event, instrument, trace};
use uuid::Uuid;
use validator::Validate;
use app_error::AppError;
use appflowy_collaborate::actix_ws::entities::ClientStreamMessage;
use appflowy_collaborate::indexer::IndexerProvider;
use authentication::jwt::{Authorization, OptionalUserUuid, UserUuid};
use collab_rt_entity::realtime_proto::HttpRealtimeMessage;
use collab_rt_entity::RealtimeMessage;
use collab_rt_protocol::validate_encode_collab;
use database::collab::{CollabStorage, GetCollabOrigin};
use database::user::select_uid_from_email;
use database_entity::dto::PublishCollabItem;
use database_entity::dto::PublishInfo;
use database_entity::dto::*;
use shared_entity::dto::workspace_dto::*;
use shared_entity::response::AppResponseError;
use shared_entity::response::{AppResponse, JsonAppResponse};
use crate::api::util::PayloadReader;
use crate::api::util::{client_version_from_headers, PayloadReader};
use crate::api::util::{compress_type_from_header_value, device_id_from_headers, CollabValidator};
use crate::api::ws::RealtimeServerAddr;
use crate::biz;
@ -61,7 +21,47 @@ use crate::domain::compression::{
blocking_decompress, decompress, CompressionType, X_COMPRESSION_TYPE,
};
use crate::state::AppState;
use access_control::act::Action;
use actix_web::web::{Bytes, Path, Payload};
use actix_web::web::{Data, Json, PayloadConfig};
use actix_web::{web, HttpResponse, ResponseError, Scope};
use actix_web::{HttpRequest, Result};
use anyhow::{anyhow, Context};
use app_error::AppError;
use appflowy_collaborate::actix_ws::entities::{ClientHttpStreamMessage, ClientHttpUpdateMessage};
use appflowy_collaborate::indexer::IndexerProvider;
use authentication::jwt::{Authorization, OptionalUserUuid, UserUuid};
use bytes::BytesMut;
use chrono::{DateTime, Duration, Utc};
use collab::entity::EncodedCollab;
use collab_database::entity::FieldType;
use collab_entity::CollabType;
use collab_folder::timestamp;
use collab_rt_entity::collab_proto::{CollabDocStateParams, PayloadCompressionType};
use collab_rt_entity::realtime_proto::HttpRealtimeMessage;
use collab_rt_entity::user::RealtimeUser;
use collab_rt_entity::RealtimeMessage;
use collab_rt_protocol::validate_encode_collab;
use database::collab::{CollabStorage, GetCollabOrigin};
use database::user::select_uid_from_email;
use database_entity::dto::PublishCollabItem;
use database_entity::dto::PublishInfo;
use database_entity::dto::*;
use futures_util::future::try_join_all;
use prost::Message as ProstMessage;
use rayon::prelude::*;
use shared_entity::dto::workspace_dto::*;
use shared_entity::response::AppResponseError;
use shared_entity::response::{AppResponse, JsonAppResponse};
use sqlx::types::uuid;
use std::collections::HashMap;
use std::io::Cursor;
use std::time::Instant;
use tokio_stream::StreamExt;
use tokio_tungstenite::tungstenite::Message;
use tracing::{error, event, instrument, trace};
use uuid::Uuid;
use validator::Validate;
pub const WORKSPACE_ID_PATH: &str = "workspace_id";
pub const COLLAB_OBJECT_ID_PATH: &str = "object_id";
@ -127,10 +127,25 @@ pub fn workspace_scope() -> Scope {
web::resource("/v1/{workspace_id}/collab/{object_id}")
.route(web::get().to(v1_get_collab_handler)),
)
.service(
web::resource("/v1/{workspace_id}/collab/{object_id}/sync")
.route(web::post().to(collab_two_way_sync_handler)),
)
.service(
web::resource("/v1/{workspace_id}/collab/{object_id}/web-update")
.route(web::post().to(post_web_update_handler)),
)
.service(
web::resource("/{workspace_id}/collab/{object_id}/member")
.route(web::post().to(add_collab_member_handler))
.route(web::get().to(get_collab_member_handler))
.route(web::put().to(update_collab_member_handler))
.route(web::delete().to(remove_collab_member_handler)),
)
.service(
web::resource("/{workspace_id}/collab/{object_id}/info")
.route(web::get().to(get_collab_info_handler)),
)
.service(web::resource("/{workspace_id}/space").route(web::post().to(post_space_handler)))
.service(
web::resource("/{workspace_id}/space/{view_id}").route(web::patch().to(update_space_handler)),
@ -175,13 +190,6 @@ pub fn workspace_scope() -> Scope {
web::resource("/{workspace_id}/{object_id}/snapshot/list")
.route(web::get().to(get_all_collab_snapshot_list_handler)),
)
.service(
web::resource("/{workspace_id}/collab/{object_id}/member")
.route(web::post().to(add_collab_member_handler))
.route(web::get().to(get_collab_member_handler))
.route(web::put().to(update_collab_member_handler))
.route(web::delete().to(remove_collab_member_handler)),
)
.service(
web::resource("/published/{publish_namespace}")
.route(web::get().to(get_default_published_collab_info_meta_handler)),
@ -907,12 +915,24 @@ async fn v1_get_collab_handler(
Ok(Json(AppResponse::Ok().with_data(resp)))
}
#[instrument(level = "debug", skip_all)]
async fn post_web_update_handler(
user_uuid: UserUuid,
path: web::Path<(Uuid, Uuid)>,
payload: Json<UpdateCollabWebParams>,
state: Data<AppState>,
server: Data<RealtimeServerAddr>,
req: HttpRequest,
) -> Result<Json<AppResponse<()>>> {
let payload = payload.into_inner();
let app_version = client_version_from_headers(req.headers())
.map(|s| s.to_string())
.unwrap_or_else(|_| "web".to_string());
let device_id = device_id_from_headers(req.headers())
.map(|s| s.to_string())
.unwrap_or_else(|_| Uuid::new_v4().to_string());
let session_id = device_id.clone();
let (workspace_id, object_id) = path.into_inner();
let collab_type = payload.collab_type.clone();
let uid = state
@ -920,15 +940,24 @@ async fn post_web_update_handler(
.get_user_uid(&user_uuid)
.await
.map_err(AppResponseError::from)?;
update_page_collab_data(
&state.pg_pool,
state.collab_access_control_storage.clone(),
state.metrics.appflowy_web_metrics.clone(),
let user = RealtimeUser {
uid,
device_id,
connect_at: timestamp(),
session_id,
app_version,
};
trace!("create onetime web realtime user: {}", user);
update_page_collab_data(
&state.metrics.appflowy_web_metrics,
server,
user,
workspace_id,
object_id,
collab_type,
&payload.doc_state,
payload.doc_state,
)
.await?;
Ok(Json(AppResponse::Ok()))
@ -1225,10 +1254,7 @@ async fn update_collab_handler(
let create_params = CreateCollabParams::from((workspace_id.to_string(), params));
let (mut params, workspace_id) = create_params.split();
if let Some(indexer) = state
.indexer_provider
.indexer_for(params.collab_type.clone())
{
if let Some(indexer) = state.indexer_provider.indexer_for(&params.collab_type) {
if state
.indexer_provider
.can_index_workspace(&workspace_id)
@ -1347,7 +1373,7 @@ async fn update_collab_member_handler(
}
#[instrument(level = "debug", skip(state, payload), err)]
async fn get_collab_member_handler(
payload: Json<CollabMemberIdentify>,
payload: Json<WorkspaceCollabIdentify>,
state: Data<AppState>,
) -> Result<Json<AppResponse<AFCollabMember>>> {
let payload = payload.into_inner();
@ -1357,7 +1383,7 @@ async fn get_collab_member_handler(
#[instrument(skip(state, payload), err)]
async fn remove_collab_member_handler(
payload: Json<CollabMemberIdentify>,
payload: Json<WorkspaceCollabIdentify>,
state: Data<AppState>,
) -> Result<Json<AppResponse<()>>> {
let payload = payload.into_inner();
@ -1750,8 +1776,9 @@ async fn post_realtime_message_stream_handler(
state: Data<AppState>,
req: HttpRequest,
) -> Result<Json<AppResponse<()>>> {
// TODO(nathan): after upgrade the client application, then the device_id should not be empty
let device_id = device_id_from_headers(req.headers()).unwrap_or_else(|_| "".to_string());
let device_id = device_id_from_headers(req.headers())
.map(|s| s.to_string())
.unwrap_or_else(|_| "".to_string());
let uid = state
.user_cache
.get_user_uid(&user_uuid)
@ -1767,7 +1794,7 @@ async fn post_realtime_message_stream_handler(
let device_id = device_id.to_string();
let message = parser_realtime_msg(bytes.freeze(), req.clone()).await?;
let stream_message = ClientStreamMessage {
let stream_message = ClientHttpStreamMessage {
uid,
device_id,
message,
@ -2159,3 +2186,130 @@ async fn fetch_embeddings(
Ok(())
}
/// Fetches metadata about a single collab object — identified by the
/// `object_id` in the request body, with its collab type supplied as a
/// query parameter — and returns it as an `AFCollabInfo`.
///
/// Responds with `RecordNotFound` when no matching collab exists.
#[instrument(level = "debug", skip(state, payload), err)]
async fn get_collab_info_handler(
  payload: Json<WorkspaceCollabIdentify>,
  query: web::Query<CollabTypeParam>,
  state: Data<AppState>,
) -> Result<Json<AppResponse<AFCollabInfo>>> {
  let payload = payload.into_inner();
  let collab_type = query.into_inner().collab_type;
  let maybe_info =
    database::collab::get_collab_info(&state.pg_pool, &payload.object_id, collab_type)
      .await
      .map_err(AppResponseError::from)?;
  match maybe_info {
    Some(info) => Ok(Json(AppResponse::Ok().with_data(info))),
    None => Err(
      AppError::RecordNotFound(format!(
        "Collab with object_id {} not found",
        payload.object_id
      ))
      .into(),
    ),
  }
}
#[instrument(level = "debug", skip_all, err)]
async fn collab_two_way_sync_handler(
user_uuid: UserUuid,
body: Bytes,
path: web::Path<(Uuid, Uuid)>,
state: Data<AppState>,
server: Data<RealtimeServerAddr>,
req: HttpRequest,
) -> Result<HttpResponse> {
if body.is_empty() {
return Err(AppError::InvalidRequest("body is empty".to_string()).into());
}
// when the payload size exceeds the limit, we consider it as an invalid payload.
const MAX_BODY_SIZE: usize = 1024 * 1024 * 50; // 50MB
if body.len() > MAX_BODY_SIZE {
error!("Unexpected large body size: {}", body.len());
return Err(
AppError::InvalidRequest(format!("body size exceeds limit: {}", MAX_BODY_SIZE)).into(),
);
}
let (workspace_id, object_id) = path.into_inner();
let params = CollabDocStateParams::decode(&mut Cursor::new(body)).map_err(|err| {
AppError::InvalidRequest(format!("Failed to parse CreateCollabEmbedding: {}", err))
})?;
if params.doc_state.is_empty() {
return Err(AppError::InvalidRequest("doc state is empty".to_string()).into());
}
let collab_type = CollabType::from(params.collab_type);
let compression_type = PayloadCompressionType::try_from(params.compression).map_err(|err| {
AppError::InvalidRequest(format!("Failed to parse PayloadCompressionType: {}", err))
})?;
let doc_state = match compression_type {
PayloadCompressionType::None => params.doc_state,
PayloadCompressionType::Zstd => tokio::task::spawn_blocking(move || {
zstd::decode_all(&*params.doc_state)
.map_err(|err| AppError::InvalidRequest(format!("Failed to decompress doc_state: {}", err)))
})
.await
.map_err(AppError::from)??,
};
let sv = match compression_type {
PayloadCompressionType::None => params.sv,
PayloadCompressionType::Zstd => tokio::task::spawn_blocking(move || {
zstd::decode_all(&*params.sv)
.map_err(|err| AppError::InvalidRequest(format!("Failed to decompress sv: {}", err)))
})
.await
.map_err(AppError::from)??,
};
let app_version = client_version_from_headers(req.headers())
.map(|s| s.to_string())
.unwrap_or_else(|_| "".to_string());
let device_id = device_id_from_headers(req.headers())
.map(|s| s.to_string())
.unwrap_or_else(|_| "".to_string());
let uid = state
.user_cache
.get_user_uid(&user_uuid)
.await
.map_err(AppResponseError::from)?;
let user = RealtimeUser {
uid,
device_id,
connect_at: timestamp(),
session_id: uuid::Uuid::new_v4().to_string(),
app_version,
};
let (tx, rx) = tokio::sync::oneshot::channel();
let message = ClientHttpUpdateMessage {
user,
workspace_id: workspace_id.to_string(),
object_id: object_id.to_string(),
collab_type,
update: Bytes::from(doc_state),
state_vector: Some(Bytes::from(sv)),
return_tx: Some(tx),
};
server
.try_send(message)
.map_err(|err| AppError::Internal(anyhow!("Failed to send message to server: {}", err)))?;
match rx
.await
.map_err(|err| AppError::Internal(anyhow!("Failed to receive message from server: {}", err)))?
{
Ok(Some(data)) => {
let encoded = tokio::task::spawn_blocking(move || zstd::encode_all(Cursor::new(data), 3))
.await
.map_err(|err| AppError::Internal(anyhow!("Failed to compress data: {}", err)))??;
Ok(HttpResponse::Ok().body(encoded))
},
Ok(None) => Ok(HttpResponse::InternalServerError().finish()),
Err(err) => Ok(err.error_response()),
}
}

View file

@ -28,9 +28,9 @@ use database::collab::select_workspace_database_oid;
use database::collab::{CollabStorage, GetCollabOrigin};
use database::publish::select_published_view_ids_for_workspace;
use database::publish::select_workspace_id_for_publish_namespace;
use database_entity::dto::CollabParams;
use database_entity::dto::QueryCollab;
use database_entity::dto::QueryCollabResult;
use database_entity::dto::{CollabParams, WorkspaceCollabIdentify};
use shared_entity::dto::workspace_dto::AFDatabase;
use shared_entity::dto::workspace_dto::AFDatabaseField;
use shared_entity::dto::workspace_dto::AFDatabaseRow;
@ -44,23 +44,19 @@ use shared_entity::dto::workspace_dto::TrashFolderView;
use sqlx::PgPool;
use std::ops::DerefMut;
use crate::biz::collab::utils::field_by_name_uniq;
use crate::biz::workspace::ops::broadcast_update;
use access_control::collab::CollabAccessControl;
use anyhow::Context;
use database_entity::dto::{
AFCollabMember, InsertCollabMemberParams, QueryCollabMembers, UpdateCollabMemberParams,
};
use shared_entity::dto::workspace_dto::{FolderView, PublishedView};
use sqlx::types::Uuid;
use std::collections::HashSet;
use tracing::{event, trace};
use validator::Validate;
use access_control::collab::CollabAccessControl;
use database_entity::dto::{
AFCollabMember, CollabMemberIdentify, InsertCollabMemberParams, QueryCollabMembers,
UpdateCollabMemberParams,
};
use crate::biz::collab::utils::field_by_name_uniq;
use crate::biz::workspace::ops::broadcast_update;
use super::folder_view::collab_folder_to_folder_view;
use super::folder_view::section_items_to_favorite_folder_view;
use super::folder_view::section_items_to_recent_folder_view;
@ -159,7 +155,7 @@ pub async fn upsert_collab_member(
pub async fn get_collab_member(
pg_pool: &PgPool,
params: &CollabMemberIdentify,
params: &WorkspaceCollabIdentify,
) -> Result<AFCollabMember, AppError> {
params.validate()?;
let collab_member =
@ -169,7 +165,7 @@ pub async fn get_collab_member(
pub async fn delete_collab_member(
pg_pool: &PgPool,
params: &CollabMemberIdentify,
params: &WorkspaceCollabIdentify,
collab_access_control: Arc<dyn CollabAccessControl>,
) -> Result<(), AppError> {
params.validate()?;

View file

@ -19,7 +19,7 @@ use yrs::updates::encoder::Encode;
use access_control::workspace::WorkspaceAccessControl;
use app_error::AppError;
use appflowy_collaborate::collab::storage::CollabAccessControlStorage;
use database::collab::{upsert_collab_member_with_txn, CollabStorage};
use database::collab::upsert_collab_member_with_txn;
use database::file::s3_client_impl::S3BucketStorage;
use database::pg_row::AFWorkspaceMemberRow;

View file

@ -1,6 +1,18 @@
use super::ops::broadcast_update;
use crate::api::metrics::AppFlowyWebMetrics;
use crate::api::ws::RealtimeServerAddr;
use crate::biz::collab::folder_view::{
check_if_view_is_space, parse_extra_field_as_json, to_dto_view_icon, to_dto_view_layout,
to_folder_view_icon, to_space_permission,
};
use crate::biz::collab::ops::{get_latest_collab_folder, get_latest_workspace_database};
use crate::biz::collab::utils::{collab_from_doc_state, get_latest_collab_encoded};
use actix_web::web::Data;
use anyhow::anyhow;
use app_error::AppError;
use appflowy_collaborate::actix_ws::entities::ClientHttpUpdateMessage;
use appflowy_collaborate::collab::storage::CollabAccessControlStorage;
use bytes::Bytes;
use chrono::DateTime;
use collab::core::collab::Collab;
use collab_database::database::{
@ -24,10 +36,11 @@ use collab_document::document_data::default_document_data;
use collab_entity::{CollabType, EncodedCollab};
use collab_folder::hierarchy_builder::NestedChildViewBuilder;
use collab_folder::{timestamp, CollabOrigin, Folder};
use collab_rt_entity::user::RealtimeUser;
use database::collab::{select_workspace_database_oid, CollabStorage, GetCollabOrigin};
use database::publish::select_published_view_ids_for_workspace;
use database::user::select_web_user_from_uid;
use database_entity::dto::{CollabParams, QueryCollab, QueryCollabParams, QueryCollabResult};
use database_entity::dto::{CollabParams, QueryCollab, QueryCollabResult};
use itertools::Itertools;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use serde_json::json;
@ -38,20 +51,8 @@ use sqlx::{PgPool, Transaction};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Instant;
use tracing::instrument;
use uuid::Uuid;
use yrs::updates::decoder::Decode;
use yrs::Update;
use crate::api::metrics::AppFlowyWebMetrics;
use crate::biz::collab::folder_view::{
parse_extra_field_as_json, to_dto_view_icon, to_dto_view_layout, to_folder_view_icon,
to_space_permission,
};
use crate::biz::collab::ops::get_latest_workspace_database;
use crate::biz::collab::utils::{collab_from_doc_state, get_latest_collab_encoded};
use crate::biz::collab::{folder_view::check_if_view_is_space, ops::get_latest_collab_folder};
use super::ops::broadcast_update;
struct WorkspaceDatabaseUpdate {
pub updated_encoded_collab: Vec<u8>,
@ -1275,64 +1276,32 @@ async fn get_page_collab_data_for_document(
})
}
#[allow(clippy::too_many_arguments)]
#[instrument(level = "debug", skip_all)]
pub async fn update_page_collab_data(
pg_pool: &PgPool,
collab_access_control_storage: Arc<CollabAccessControlStorage>,
appflowy_web_metrics: Arc<AppFlowyWebMetrics>,
uid: i64,
appflowy_web_metrics: &Arc<AppFlowyWebMetrics>,
server: Data<RealtimeServerAddr>,
user: RealtimeUser,
workspace_id: Uuid,
object_id: Uuid,
collab_type: CollabType,
doc_state: &[u8],
doc_state: Vec<u8>,
) -> Result<(), AppError> {
let param = QueryCollabParams {
workspace_id: workspace_id.to_string(),
inner: QueryCollab {
object_id: object_id.to_string(),
collab_type: collab_type.clone(),
},
};
let encode_collab = collab_access_control_storage
.get_encode_collab(GetCollabOrigin::User { uid }, param, true)
.await?;
let mut collab = collab_from_doc_state(encode_collab.doc_state.to_vec(), &object_id.to_string())?;
let object_id = object_id.to_string();
appflowy_web_metrics.record_update_size_bytes(doc_state.len());
let update = Update::decode_v1(doc_state).map_err(|e| {
appflowy_web_metrics.incr_decoding_failure_count(1);
AppError::InvalidRequest(format!("Failed to decode update: {}", e))
})?;
collab.apply_update(update).map_err(|e| {
appflowy_web_metrics.incr_apply_update_failure_count(1);
AppError::InvalidRequest(format!("Failed to apply update: {}", e))
})?;
let updated_encoded_collab = collab
.encode_collab_v1(|c| collab_type.validate_require_data(c))
.map_err(|e| AppError::Internal(anyhow!("Failed to encode collab: {}", e)))?
.encode_to_bytes()?;
let params = CollabParams {
let message = ClientHttpUpdateMessage {
user,
workspace_id: workspace_id.to_string(),
object_id: object_id.to_string(),
collab_type: collab_type.clone(),
encoded_collab_v1: updated_encoded_collab.into(),
embeddings: None,
collab_type,
update: Bytes::from(doc_state),
state_vector: None,
return_tx: None,
};
let mut transaction = pg_pool.begin().await?;
collab_access_control_storage
.upsert_new_collab_with_transaction(
&workspace_id.to_string(),
&uid,
params,
&mut transaction,
"upsert collab",
)
.await?;
transaction.commit().await?;
broadcast_update(
&collab_access_control_storage,
&object_id.to_string(),
doc_state.to_vec(),
)
.await?;
server
.try_send(message)
.map_err(|err| AppError::Internal(anyhow!("Failed to send message to server: {}", err)))?;
Ok(())
}

View file

@ -2,6 +2,7 @@ use actix_http::header::{HeaderName, HeaderValue};
use std::future::{ready, Ready};
use tracing::{span, Instrument, Level};
use crate::api::util::{client_version_from_headers, device_id_from_headers};
use actix_service::{forward_ready, Service, Transform};
use actix_web::dev::{ServiceRequest, ServiceResponse};
use futures_util::future::LocalBoxFuture;
@ -106,15 +107,8 @@ fn get_client_info(req: &ServiceRequest) -> ClientInfo {
.and_then(|val| val.parse::<usize>().ok())
.unwrap_or_default();
let client_version = req
.headers()
.get("client-version")
.and_then(|val| val.to_str().ok());
let device_id = req
.headers()
.get("device_id")
.and_then(|val| val.to_str().ok());
let client_version = client_version_from_headers(req.headers()).ok();
let device_id = device_id_from_headers(req.headers()).ok();
ClientInfo {
payload_size,

View file

@ -3,8 +3,8 @@ use client_api_test::{generate_unique_registered_user_client, workspace_id_from_
use collab_entity::CollabType;
use database_entity::dto::{
AFAccessLevel, CollabMemberIdentify, CreateCollabParams, InsertCollabMemberParams,
QueryCollabMembers, UpdateCollabMemberParams,
AFAccessLevel, CreateCollabParams, InsertCollabMemberParams, QueryCollabMembers,
UpdateCollabMemberParams, WorkspaceCollabIdentify,
};
use uuid::Uuid;
@ -28,7 +28,7 @@ async fn collab_owner_permission_test() {
.unwrap();
let member = c
.get_collab_member(CollabMemberIdentify {
.get_collab_member(WorkspaceCollabIdentify {
uid,
object_id,
workspace_id,
@ -68,7 +68,7 @@ async fn update_collab_member_permission_test() {
.unwrap();
let member = c
.get_collab_member(CollabMemberIdentify {
.get_collab_member(WorkspaceCollabIdentify {
uid,
object_id,
workspace_id,
@ -112,7 +112,7 @@ async fn add_collab_member_test() {
// check the member is added and its permission is correct
let member = c_1
.get_collab_member(CollabMemberIdentify {
.get_collab_member(WorkspaceCollabIdentify {
uid: uid_2,
object_id,
workspace_id,
@ -167,7 +167,7 @@ async fn add_collab_member_then_remove_test() {
// Delete the member
c_1
.remove_collab_member(CollabMemberIdentify {
.remove_collab_member(WorkspaceCollabIdentify {
uid: uid_2,
object_id: object_id.clone(),
workspace_id: workspace_id.clone(),

View file

@ -612,44 +612,163 @@ async fn simulate_10_offline_user_connect_and_then_sync_document_test() {
}
}
// #[tokio::test]
// async fn simulate_50_user_connect_and_then_sync_document_test() {
// let users = Arc::new(RwLock::new(vec![]));
// let mut tasks = vec![];
// for i in 0..50 {
// let task = tokio::spawn(async move {
// let new_user = TestClient::new_user().await;
// // sleep to make sure it do not trigger register user too fast in gotrue
// sleep(Duration::from_secs(i % 5)).await;
// new_user
// });
// tasks.push(task);
// }
// let results = futures::future::join_all(tasks).await;
// for result in results {
// users.write().await.push(result.unwrap());
// }
//
// let text = generate_random_string(1024 * 1024 * 3);
// let mut tasks = Vec::new();
// for i in 0..100 {
// let cloned_text = text.clone();
// let cloned_users = users.clone();
// let task = tokio::spawn(async move {
// let object_id = Uuid::new_v4().to_string();
// sleep(Duration::from_secs(1)).await;
// let workspace_id = cloned_users.read().await[i % 50].workspace_id().await;
// let doc_state = make_big_collab_doc_state(&object_id, "text", cloned_text);
// cloned_users.write().await[i % 50]
// .open_collab_with_doc_state(&workspace_id, &object_id, CollabType::Unknown, doc_state)
// .await;
// sleep(Duration::from_secs(6)).await;
// });
// tasks.push(task);
// }
//
// let results = futures::future::join_all(tasks).await;
// for result in results {
// result.unwrap();
// }
// }
#[tokio::test]
async fn offline_and_then_sync_through_http_request() {
  // Set up a user with an open collab, wait until it is synced, then go offline.
  let mut client = TestClient::new_user().await;
  let workspace_id = client.workspace_id().await;
  let object_id = Uuid::new_v4().to_string();
  let initial_doc_state = make_big_collab_doc_state(&object_id, "1", "".to_string());
  client
    .open_collab_with_doc_state(
      &workspace_id,
      &object_id,
      CollabType::Unknown,
      initial_doc_state,
    )
    .await;
  client.wait_object_sync_complete(&object_id).await.unwrap();
  client.disconnect().await;

  // While offline, the server should still only know the empty initial state.
  assert_server_collab(
    &workspace_id,
    &mut client.api_client,
    &object_id,
    &CollabType::Unknown,
    10,
    json!({"1":""}),
  )
  .await
  .unwrap();

  // Round 1: insert a small chunk locally, then push it over HTTP.
  let small_text = generate_random_string(100);
  client.insert_into(&object_id, "1", small_text.clone()).await;
  let encoded = client.collabs.get(&object_id).unwrap().encode_collab().await;
  client
    .api_client
    .post_collab_doc_state(
      &workspace_id,
      &object_id,
      CollabType::Unknown,
      encoded.doc_state.to_vec(),
      encoded.state_vector.to_vec(),
    )
    .await
    .unwrap();

  // The server must now reflect the small chunk.
  assert_server_collab(
    &workspace_id,
    &mut client.api_client,
    &object_id,
    &CollabType::Unknown,
    10,
    json!({"1": small_text.clone()}),
  )
  .await
  .unwrap();

  // Round 2: insert a medium chunk and push again.
  let medium_text = generate_random_string(512);
  client
    .insert_into(&object_id, "2", medium_text.clone())
    .await;
  let encoded = client.collabs.get(&object_id).unwrap().encode_collab().await;
  client
    .api_client
    .post_collab_doc_state(
      &workspace_id,
      &object_id,
      CollabType::Unknown,
      encoded.doc_state.to_vec(),
      encoded.state_vector.to_vec(),
    )
    .await
    .unwrap();

  // Both chunks must now be visible on the server.
  assert_server_collab(
    &workspace_id,
    &mut client.api_client,
    &object_id,
    &CollabType::Unknown,
    10,
    json!({"1": small_text, "2": medium_text}),
  )
  .await
  .unwrap();
}
#[tokio::test]
async fn insert_text_through_http_post_request() {
  // Create a user, open an empty collab, sync it, then drop the websocket.
  let mut client = TestClient::new_user().await;
  let workspace_id = client.workspace_id().await;
  let object_id = Uuid::new_v4().to_string();
  let initial_doc_state = make_big_collab_doc_state(&object_id, "1", "".to_string());
  client
    .open_collab_with_doc_state(
      &workspace_id,
      &object_id,
      CollabType::Unknown,
      initial_doc_state,
    )
    .await;
  client.wait_object_sync_complete(&object_id).await.unwrap();
  client.disconnect().await;

  // Apply 1000 local inserts while offline, recording the expected contents.
  let mut expected = HashMap::new();
  for i in 0..1000 {
    let key = i.to_string();
    let value = generate_random_string(10);
    client.insert_into(&object_id, &key, value.clone()).await;
    expected.insert(key, value);
    // Pause every 100 inserts (including the first) to avoid hammering the runtime.
    if i % 100 == 0 {
      tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
    }
  }

  // Push the accumulated doc state and state vector through the HTTP endpoint.
  let encoded = client.collabs.get(&object_id).unwrap().encode_collab().await;
  client
    .api_client
    .post_collab_doc_state(
      &workspace_id,
      &object_id,
      CollabType::Unknown,
      encoded.doc_state.to_vec(),
      encoded.state_vector.to_vec(),
    )
    .await
    .unwrap();

  // The server copy must match everything inserted offline.
  assert_server_collab(
    &workspace_id,
    &mut client.api_client,
    &object_id,
    &CollabType::Unknown,
    10,
    json!(expected),
  )
  .await
  .unwrap();
}

View file

@ -2,12 +2,11 @@ use actix::{Actor, Context, Handler};
use appflowy_collaborate::actix_ws::client::rt_client::{
HandlerResult, RealtimeClient, RealtimeServer,
};
use appflowy_collaborate::actix_ws::entities::{ClientMessage, Connect, Disconnect};
use appflowy_collaborate::actix_ws::entities::{ClientWebSocketMessage, Connect, Disconnect};
use appflowy_collaborate::error::RealtimeError;
use collab_rt_entity::user::RealtimeUser;
use collab_rt_entity::RealtimeMessage;
use collab_rt_entity::{MessageByObjectId, RealtimeMessage};
use semver::Version;
use std::collections::HashMap;
use std::time::Duration;
#[actix_rt::test]
@ -29,9 +28,10 @@ async fn test_handle_message() {
10,
);
let mut message_by_oid = HashMap::new();
message_by_oid.insert("object_id".to_string(), vec![]);
let message = RealtimeMessage::ClientCollabV2(message_by_oid);
let message = RealtimeMessage::ClientCollabV2(MessageByObjectId::new_with_message(
"object_id".to_string(),
vec![],
));
client.try_send(message).unwrap();
}
@ -64,9 +64,10 @@ async fn server_mailbox_full_test() {
10,
);
for _ in 0..10 {
let mut message_by_oid = HashMap::new();
message_by_oid.insert("object_id".to_string(), vec![]);
let message = RealtimeMessage::ClientCollabV2(message_by_oid);
let message = RealtimeMessage::ClientCollabV2(MessageByObjectId::new_with_message(
"object_id".to_string(),
vec![],
));
client.try_send(message).unwrap();
}
});
@ -110,9 +111,10 @@ async fn client_rate_limit_hit_test() {
1,
);
for _ in 0..10 {
let mut message_by_oid = HashMap::new();
message_by_oid.insert("object_id".to_string(), vec![]);
let message = RealtimeMessage::ClientCollabV2(message_by_oid);
let message = RealtimeMessage::ClientCollabV2(MessageByObjectId::new_with_message(
"object_id".to_string(),
vec![],
));
if let Err(err) = client.try_send(message) {
if err.is_too_many_message() {
continue;
@ -148,10 +150,10 @@ impl Actor for MockRealtimeServer {
}
}
impl Handler<ClientMessage> for MockRealtimeServer {
impl Handler<ClientWebSocketMessage> for MockRealtimeServer {
type Result = HandlerResult;
fn handle(&mut self, _msg: ClientMessage, _ctx: &mut Self::Context) -> Self::Result {
fn handle(&mut self, _msg: ClientWebSocketMessage, _ctx: &mut Self::Context) -> Self::Result {
Ok(())
}
}