chore: Optimize token refresh (#1320)

* fix: potential race condition when drop txs_guard to early

* chore: remove unused tests
This commit is contained in:
Nathan.fooo 2025-04-05 14:30:20 +08:00 committed by GitHub
parent 7a577492da
commit b94d1c60b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 26 additions and 100 deletions

View file

@ -21,7 +21,6 @@ pub(crate) const SEND_INTERVAL: Duration = Duration::from_secs(8);
pub const COLLAB_SINK_DELAY_MILLIS: u64 = 500;
pub struct CollabSink<Sink> {
#[allow(dead_code)]
uid: i64,
config: SinkConfig,
object: SyncObject,
@ -290,13 +289,10 @@ where
) {
(Some(msg_queue), Some(sending_messages)) => (msg_queue, sending_messages),
_ => {
// If acquire the lock failed, try later
if cfg!(feature = "sync_verbose_log") {
trace!(
"{}: failed to acquire the lock of the sink, retry later",
self.object.object_id
);
}
warn!(
"{}: failed to acquire the lock of the sink, retry later",
self.object.object_id
);
retry_later(Arc::downgrade(&self.notifier));
return;
},

View file

@ -8,7 +8,7 @@ use collab::core::origin::CollabOrigin;
use collab::preclude::Collab;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{broadcast, watch};
use tracing::{error, instrument, trace};
use tracing::{error, info, instrument, trace};
use yrs::updates::decoder::Decode;
use yrs::updates::encoder::{Encode, Encoder, EncoderV1};
use yrs::{ReadTxn, StateVector};
@ -96,16 +96,12 @@ where
}
pub fn pause(&self) {
if cfg!(feature = "sync_verbose_log") {
trace!("pause {} sync", self.object.object_id);
}
info!("pause {} sync", self.object.object_id);
self.sink.pause();
}
pub fn resume(&self) {
if cfg!(feature = "sync_verbose_log") {
trace!("resume {} sync", self.object.object_id);
}
info!("resume {} sync", self.object.object_id);
self.sink.resume();
}

View file

@ -989,24 +989,34 @@ impl Client {
/// Refreshes the access token using the stored refresh token.
///
/// This function attempts to refresh the access token by sending a request to the authentication server
/// attempts to refresh the access token by sending a request to the authentication server
/// using the stored refresh token. If successful, it updates the stored access token with the new one
/// received from the server.
/// Refreshes the access token using the stored refresh token.
#[instrument(level = "debug", skip_all, err)]
pub async fn refresh_token(&self, reason: &str) -> Result<(), AppResponseError> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.refresh_ret_txs.write().push(tx);
if !self.is_refreshing_token.load(Ordering::SeqCst) {
self.is_refreshing_token.store(true, Ordering::SeqCst);
// Atomically check and set the refreshing flag to prevent race conditions
if self
.is_refreshing_token
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
{
info!("refresh token reason:{}", reason);
let result = self.inner_refresh_token().await;
let txs = std::mem::take(&mut *self.refresh_ret_txs.write());
// Process all pending requests and reset state atomically
let mut txs_guard = self.refresh_ret_txs.write();
let txs = std::mem::take(&mut *txs_guard);
self.is_refreshing_token.store(false, Ordering::SeqCst);
drop(txs_guard);
// Send results to all waiting requests
for tx in txs {
let _ = tx.send(result.clone());
}
self.is_refreshing_token.store(false, Ordering::SeqCst);
} else {
debug!("refresh token is already in progress");
}
@ -1014,11 +1024,10 @@ impl Client {
// Wait for the result of the refresh token request.
match tokio::time::timeout(Duration::from_secs(60), rx).await {
Ok(Ok(result)) => result,
Ok(Err(err)) => Err(AppError::Internal(anyhow!("refresh token error: {}", err)).into()),
Err(_) => {
self.is_refreshing_token.store(false, Ordering::SeqCst);
Err(AppError::RequestTimeout("refresh token timeout".to_string()).into())
Ok(Err(err)) => {
Err(AppError::Internal(anyhow!("refresh token channel error: {}", err)).into())
},
Err(_) => Err(AppError::RequestTimeout("refresh token timeout".to_string()).into()),
}
}

View file

@ -1,2 +0,0 @@
// mod native;
// mod web;

View file

@ -1,47 +0,0 @@
use std::time::SystemTime;
use crate::user::utils::generate_unique_registered_user_client;
use client_api::ws::{ConnectState, WSClient, WSClientConfig};
#[tokio::test]
async fn realtime_connect_test() {
let (c, _user) = generate_unique_registered_user_client().await;
let ws_client = WSClient::new(WSClientConfig::default(), c.clone(), c.clone());
let mut state = ws_client.subscribe_connect_state();
let device_id = "fake_device_id";
loop {
tokio::select! {
_ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {},
value = state.recv() => {
let new_state = value.unwrap();
if new_state == ConnectState::Connected {
break;
}
},
}
}
}
#[tokio::test]
async fn realtime_disconnect_test() {
let (c, _user) = generate_unique_registered_user_client().await;
let ws_client = WSClient::new(WSClientConfig::default(), c.clone(), c.clone());
let device_id = "fake_device_id";
ws_client
.connect(c.ws_url(device_id).await.unwrap(), device_id)
.await
.unwrap();
let mut state = ws_client.subscribe_connect_state();
loop {
tokio::select! {
_ = ws_client.disconnect() => {},
value = state.recv() => {
let new_state = value.unwrap();
if new_state == ConnectState::Closed {
break;
}
},
}
}
}

View file

@ -1 +0,0 @@
mod conn_test;

View file

@ -1,22 +0,0 @@
use crate::user::utils::generate_unique_registered_user_client;
use client_api::ws::{ConnectState, WSClient, WSClientConfig};
use wasm_bindgen_test::wasm_bindgen_test;
#[wasm_bindgen_test]
async fn realtime_connect_test() {
let (c, _user) = generate_unique_registered_user_client().await;
let ws_client = WSClient::new(WSClientConfig::default(), c.clone(), c.clone());
let mut state = ws_client.subscribe_connect_state();
let device_id = "fake_device_id";
loop {
tokio::select! {
_ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {},
value = state.recv() => {
let new_state = value.unwrap();
if new_state == ConnectState::Connected {
break;
}
},
}
}
}

View file

@ -1,3 +0,0 @@
use wasm_bindgen_test::wasm_bindgen_test_configure;
wasm_bindgen_test_configure!(run_in_browser);
mod conn_test;