mirror of
https://github.com/AppFlowy-IO/AppFlowy.git
synced 2025-04-24 22:57:12 -04:00
add ws test
This commit is contained in:
parent
0b2339aa19
commit
260060ac5c
10 changed files with 235 additions and 71 deletions
|
@ -13,11 +13,13 @@ class ErrorCode extends $pb.ProtobufEnum {
|
||||||
static const ErrorCode InternalError = ErrorCode._(0, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'InternalError');
|
static const ErrorCode InternalError = ErrorCode._(0, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'InternalError');
|
||||||
static const ErrorCode DuplicateSource = ErrorCode._(1, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'DuplicateSource');
|
static const ErrorCode DuplicateSource = ErrorCode._(1, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'DuplicateSource');
|
||||||
static const ErrorCode UnsupportedMessage = ErrorCode._(2, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'UnsupportedMessage');
|
static const ErrorCode UnsupportedMessage = ErrorCode._(2, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'UnsupportedMessage');
|
||||||
|
static const ErrorCode Unauthorized = ErrorCode._(3, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'Unauthorized');
|
||||||
|
|
||||||
static const $core.List<ErrorCode> values = <ErrorCode> [
|
static const $core.List<ErrorCode> values = <ErrorCode> [
|
||||||
InternalError,
|
InternalError,
|
||||||
DuplicateSource,
|
DuplicateSource,
|
||||||
UnsupportedMessage,
|
UnsupportedMessage,
|
||||||
|
Unauthorized,
|
||||||
];
|
];
|
||||||
|
|
||||||
static final $core.Map<$core.int, ErrorCode> _byValue = $pb.ProtobufEnum.initByValue(values);
|
static final $core.Map<$core.int, ErrorCode> _byValue = $pb.ProtobufEnum.initByValue(values);
|
||||||
|
|
|
@ -15,11 +15,12 @@ const ErrorCode$json = const {
|
||||||
const {'1': 'InternalError', '2': 0},
|
const {'1': 'InternalError', '2': 0},
|
||||||
const {'1': 'DuplicateSource', '2': 1},
|
const {'1': 'DuplicateSource', '2': 1},
|
||||||
const {'1': 'UnsupportedMessage', '2': 2},
|
const {'1': 'UnsupportedMessage', '2': 2},
|
||||||
|
const {'1': 'Unauthorized', '2': 3},
|
||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Descriptor for `ErrorCode`. Decode as a `google.protobuf.EnumDescriptorProto`.
|
/// Descriptor for `ErrorCode`. Decode as a `google.protobuf.EnumDescriptorProto`.
|
||||||
final $typed_data.Uint8List errorCodeDescriptor = $convert.base64Decode('CglFcnJvckNvZGUSEQoNSW50ZXJuYWxFcnJvchAAEhMKD0R1cGxpY2F0ZVNvdXJjZRABEhYKElVuc3VwcG9ydGVkTWVzc2FnZRAC');
|
final $typed_data.Uint8List errorCodeDescriptor = $convert.base64Decode('CglFcnJvckNvZGUSEQoNSW50ZXJuYWxFcnJvchAAEhMKD0R1cGxpY2F0ZVNvdXJjZRABEhYKElVuc3VwcG9ydGVkTWVzc2FnZRACEhAKDFVuYXV0aG9yaXplZBAD');
|
||||||
@$core.Deprecated('Use wsErrorDescriptor instead')
|
@$core.Deprecated('Use wsErrorDescriptor instead')
|
||||||
const WsError$json = const {
|
const WsError$json = const {
|
||||||
'1': 'WsError',
|
'1': 'WsError',
|
||||||
|
|
|
@ -83,6 +83,7 @@ name = "backend"
|
||||||
path = "src/main.rs"
|
path = "src/main.rs"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
parking_lot = "0.11"
|
||||||
once_cell = "1.7.2"
|
once_cell = "1.7.2"
|
||||||
linkify = "0.5.0"
|
linkify = "0.5.0"
|
||||||
flowy-user = { path = "../rust-lib/flowy-user" }
|
flowy-user = { path = "../rust-lib/flowy-user" }
|
||||||
|
|
|
@ -1,10 +1,83 @@
|
||||||
use crate::helper::TestServer;
|
use crate::helper::TestServer;
|
||||||
use flowy_ws::WsController;
|
use flowy_ws::{WsController, WsSender, WsState};
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
pub struct WsTest {
|
||||||
|
server: TestServer,
|
||||||
|
ws_controller: Arc<RwLock<WsController>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum WsScript {
|
||||||
|
SendText(&'static str),
|
||||||
|
SendBinary(Vec<u8>),
|
||||||
|
Disconnect(&'static str),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsTest {
|
||||||
|
pub async fn new(scripts: Vec<WsScript>) -> Self {
|
||||||
|
let server = TestServer::new().await;
|
||||||
|
let ws_controller = Arc::new(RwLock::new(WsController::new()));
|
||||||
|
ws_controller
|
||||||
|
.write()
|
||||||
|
.state_callback(move |state| match state {
|
||||||
|
WsState::Connected(sender) => {
|
||||||
|
WsScriptRunner {
|
||||||
|
scripts: scripts.clone(),
|
||||||
|
sender: sender.clone(),
|
||||||
|
source: "editor".to_owned(),
|
||||||
|
}
|
||||||
|
.run();
|
||||||
|
},
|
||||||
|
_ => {},
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Self {
|
||||||
|
server,
|
||||||
|
ws_controller,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_scripts(&mut self) {
|
||||||
|
let addr = self.server.ws_addr();
|
||||||
|
self.ws_controller.write().connect(addr).unwrap().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WsScriptRunner {
|
||||||
|
scripts: Vec<WsScript>,
|
||||||
|
sender: Arc<WsSender>,
|
||||||
|
source: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsScriptRunner {
|
||||||
|
fn run(self) {
|
||||||
|
for script in self.scripts {
|
||||||
|
match script {
|
||||||
|
WsScript::SendText(text) => {
|
||||||
|
self.sender.send_text(&self.source, text).unwrap();
|
||||||
|
},
|
||||||
|
WsScript::SendBinary(bytes) => {
|
||||||
|
self.sender.send_binary(&self.source, bytes).unwrap();
|
||||||
|
},
|
||||||
|
WsScript::Disconnect(reason) => {
|
||||||
|
self.sender.send_disconnect(reason).unwrap();
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn ws_connect() {
|
async fn ws_connect() {
|
||||||
let server = TestServer::new().await;
|
let mut ws = WsTest::new(vec![
|
||||||
let mut controller = WsController::new();
|
WsScript::SendText("abc"),
|
||||||
let addr = server.ws_addr();
|
WsScript::SendText("abc"),
|
||||||
let _ = controller.connect(addr).unwrap().await;
|
WsScript::SendText("abc"),
|
||||||
|
WsScript::Disconnect("abc"),
|
||||||
|
])
|
||||||
|
.await;
|
||||||
|
ws.run_scripts().await
|
||||||
}
|
}
|
||||||
|
|
|
@ -178,15 +178,16 @@ impl UserSession {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send_ws_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), UserError> {
|
// pub fn send_ws_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(),
|
||||||
match self.ws_controller.try_read_for(Duration::from_millis(300)) {
|
// UserError> { match self.ws_controller.try_read_for(Duration::
|
||||||
None => Err(UserError::internal().context("Send ws message timeout")),
|
// from_millis(300)) { None =>
|
||||||
Some(guard) => {
|
// Err(UserError::internal().context("Send ws message timeout")),
|
||||||
let _ = guard.send_msg(msg)?;
|
// Some(guard) => {
|
||||||
Ok(())
|
// let _ = guard.send_msg(msg)?;
|
||||||
},
|
// Ok(())
|
||||||
}
|
// },
|
||||||
}
|
// }
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UserSession {
|
impl UserSession {
|
||||||
|
|
|
@ -37,7 +37,7 @@ impl WsConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Future for WsConnection {
|
impl Future for WsConnection {
|
||||||
type Output = Result<WsStream, ServerError>;
|
type Output = Result<WsStream, WsError>;
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
// [[pin]]
|
// [[pin]]
|
||||||
// poll async function. The following methods not work.
|
// poll async function. The following methods not work.
|
||||||
|
@ -65,7 +65,7 @@ impl Future for WsConnection {
|
||||||
},
|
},
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
log::debug!("🐴 ws connect failed: {:?}", error);
|
log::debug!("🐴 ws connect failed: {:?}", error);
|
||||||
Poll::Ready(Err(error_to_flowy_response(error)))
|
Poll::Ready(Err(error.into()))
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -135,21 +135,6 @@ fn post_message(tx: MsgSender, message: Result<Message, Error>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn error_to_flowy_response(error: tokio_tungstenite::tungstenite::Error) -> ServerError {
|
|
||||||
let error = match error {
|
|
||||||
Error::Http(response) => {
|
|
||||||
if response.status() == StatusCode::UNAUTHORIZED {
|
|
||||||
ServerError::unauthorized()
|
|
||||||
} else {
|
|
||||||
ServerError::internal().context(response)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ => ServerError::internal().context(error),
|
|
||||||
};
|
|
||||||
|
|
||||||
error
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Retry<F> {
|
pub struct Retry<F> {
|
||||||
f: F,
|
f: F,
|
||||||
retry_time: usize,
|
retry_time: usize,
|
||||||
|
|
|
@ -2,7 +2,7 @@ use flowy_derive::{ProtoBuf, ProtoBuf_Enum};
|
||||||
use futures_channel::mpsc::TrySendError;
|
use futures_channel::mpsc::TrySendError;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use strum_macros::Display;
|
use strum_macros::Display;
|
||||||
use tokio_tungstenite::tungstenite::Message;
|
use tokio_tungstenite::tungstenite::{http::StatusCode, Message};
|
||||||
use url::ParseError;
|
use url::ParseError;
|
||||||
|
|
||||||
#[derive(Debug, Default, Clone, ProtoBuf)]
|
#[derive(Debug, Default, Clone, ProtoBuf)]
|
||||||
|
@ -38,6 +38,7 @@ impl WsError {
|
||||||
static_user_error!(internal, ErrorCode::InternalError);
|
static_user_error!(internal, ErrorCode::InternalError);
|
||||||
static_user_error!(duplicate_source, ErrorCode::DuplicateSource);
|
static_user_error!(duplicate_source, ErrorCode::DuplicateSource);
|
||||||
static_user_error!(unsupported_message, ErrorCode::UnsupportedMessage);
|
static_user_error!(unsupported_message, ErrorCode::UnsupportedMessage);
|
||||||
|
static_user_error!(unauthorized, ErrorCode::Unauthorized);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, ProtoBuf_Enum, Display, PartialEq, Eq)]
|
#[derive(Debug, Clone, ProtoBuf_Enum, Display, PartialEq, Eq)]
|
||||||
|
@ -45,6 +46,7 @@ pub enum ErrorCode {
|
||||||
InternalError = 0,
|
InternalError = 0,
|
||||||
DuplicateSource = 1,
|
DuplicateSource = 1,
|
||||||
UnsupportedMessage = 2,
|
UnsupportedMessage = 2,
|
||||||
|
Unauthorized = 3,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::default::Default for ErrorCode {
|
impl std::default::Default for ErrorCode {
|
||||||
|
@ -64,5 +66,18 @@ impl std::convert::From<futures_channel::mpsc::TrySendError<Message>> for WsErro
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::convert::From<tokio_tungstenite::tungstenite::Error> for WsError {
|
impl std::convert::From<tokio_tungstenite::tungstenite::Error> for WsError {
|
||||||
fn from(error: tokio_tungstenite::tungstenite::Error) -> Self { WsError::internal().context(error) }
|
fn from(error: tokio_tungstenite::tungstenite::Error) -> Self {
|
||||||
|
let error = match error {
|
||||||
|
tokio_tungstenite::tungstenite::Error::Http(response) => {
|
||||||
|
if response.status() == StatusCode::UNAUTHORIZED {
|
||||||
|
WsError::unauthorized()
|
||||||
|
} else {
|
||||||
|
WsError::internal().context(response)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => WsError::internal().context(error),
|
||||||
|
};
|
||||||
|
|
||||||
|
error
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -218,6 +218,7 @@ pub enum ErrorCode {
|
||||||
InternalError = 0,
|
InternalError = 0,
|
||||||
DuplicateSource = 1,
|
DuplicateSource = 1,
|
||||||
UnsupportedMessage = 2,
|
UnsupportedMessage = 2,
|
||||||
|
Unauthorized = 3,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ::protobuf::ProtobufEnum for ErrorCode {
|
impl ::protobuf::ProtobufEnum for ErrorCode {
|
||||||
|
@ -230,6 +231,7 @@ impl ::protobuf::ProtobufEnum for ErrorCode {
|
||||||
0 => ::std::option::Option::Some(ErrorCode::InternalError),
|
0 => ::std::option::Option::Some(ErrorCode::InternalError),
|
||||||
1 => ::std::option::Option::Some(ErrorCode::DuplicateSource),
|
1 => ::std::option::Option::Some(ErrorCode::DuplicateSource),
|
||||||
2 => ::std::option::Option::Some(ErrorCode::UnsupportedMessage),
|
2 => ::std::option::Option::Some(ErrorCode::UnsupportedMessage),
|
||||||
|
3 => ::std::option::Option::Some(ErrorCode::Unauthorized),
|
||||||
_ => ::std::option::Option::None
|
_ => ::std::option::Option::None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -239,6 +241,7 @@ impl ::protobuf::ProtobufEnum for ErrorCode {
|
||||||
ErrorCode::InternalError,
|
ErrorCode::InternalError,
|
||||||
ErrorCode::DuplicateSource,
|
ErrorCode::DuplicateSource,
|
||||||
ErrorCode::UnsupportedMessage,
|
ErrorCode::UnsupportedMessage,
|
||||||
|
ErrorCode::Unauthorized,
|
||||||
];
|
];
|
||||||
values
|
values
|
||||||
}
|
}
|
||||||
|
@ -268,24 +271,26 @@ impl ::protobuf::reflect::ProtobufValue for ErrorCode {
|
||||||
|
|
||||||
static file_descriptor_proto_data: &'static [u8] = b"\
|
static file_descriptor_proto_data: &'static [u8] = b"\
|
||||||
\n\x0cerrors.proto\";\n\x07WsError\x12\x1e\n\x04code\x18\x01\x20\x01(\
|
\n\x0cerrors.proto\";\n\x07WsError\x12\x1e\n\x04code\x18\x01\x20\x01(\
|
||||||
\x0e2\n.ErrorCodeR\x04code\x12\x10\n\x03msg\x18\x02\x20\x01(\tR\x03msg*K\
|
\x0e2\n.ErrorCodeR\x04code\x12\x10\n\x03msg\x18\x02\x20\x01(\tR\x03msg*]\
|
||||||
\n\tErrorCode\x12\x11\n\rInternalError\x10\0\x12\x13\n\x0fDuplicateSourc\
|
\n\tErrorCode\x12\x11\n\rInternalError\x10\0\x12\x13\n\x0fDuplicateSourc\
|
||||||
e\x10\x01\x12\x16\n\x12UnsupportedMessage\x10\x02J\xab\x02\n\x06\x12\x04\
|
e\x10\x01\x12\x16\n\x12UnsupportedMessage\x10\x02\x12\x10\n\x0cUnauthori\
|
||||||
\0\0\n\x01\n\x08\n\x01\x0c\x12\x03\0\0\x12\n\n\n\x02\x04\0\x12\x04\x02\0\
|
zed\x10\x03J\xd4\x02\n\x06\x12\x04\0\0\x0b\x01\n\x08\n\x01\x0c\x12\x03\0\
|
||||||
\x05\x01\n\n\n\x03\x04\0\x01\x12\x03\x02\x08\x0f\n\x0b\n\x04\x04\0\x02\0\
|
\0\x12\n\n\n\x02\x04\0\x12\x04\x02\0\x05\x01\n\n\n\x03\x04\0\x01\x12\x03\
|
||||||
\x12\x03\x03\x04\x17\n\x0c\n\x05\x04\0\x02\0\x06\x12\x03\x03\x04\r\n\x0c\
|
\x02\x08\x0f\n\x0b\n\x04\x04\0\x02\0\x12\x03\x03\x04\x17\n\x0c\n\x05\x04\
|
||||||
\n\x05\x04\0\x02\0\x01\x12\x03\x03\x0e\x12\n\x0c\n\x05\x04\0\x02\0\x03\
|
\0\x02\0\x06\x12\x03\x03\x04\r\n\x0c\n\x05\x04\0\x02\0\x01\x12\x03\x03\
|
||||||
\x12\x03\x03\x15\x16\n\x0b\n\x04\x04\0\x02\x01\x12\x03\x04\x04\x13\n\x0c\
|
\x0e\x12\n\x0c\n\x05\x04\0\x02\0\x03\x12\x03\x03\x15\x16\n\x0b\n\x04\x04\
|
||||||
\n\x05\x04\0\x02\x01\x05\x12\x03\x04\x04\n\n\x0c\n\x05\x04\0\x02\x01\x01\
|
\0\x02\x01\x12\x03\x04\x04\x13\n\x0c\n\x05\x04\0\x02\x01\x05\x12\x03\x04\
|
||||||
\x12\x03\x04\x0b\x0e\n\x0c\n\x05\x04\0\x02\x01\x03\x12\x03\x04\x11\x12\n\
|
\x04\n\n\x0c\n\x05\x04\0\x02\x01\x01\x12\x03\x04\x0b\x0e\n\x0c\n\x05\x04\
|
||||||
\n\n\x02\x05\0\x12\x04\x06\0\n\x01\n\n\n\x03\x05\0\x01\x12\x03\x06\x05\
|
\0\x02\x01\x03\x12\x03\x04\x11\x12\n\n\n\x02\x05\0\x12\x04\x06\0\x0b\x01\
|
||||||
\x0e\n\x0b\n\x04\x05\0\x02\0\x12\x03\x07\x04\x16\n\x0c\n\x05\x05\0\x02\0\
|
\n\n\n\x03\x05\0\x01\x12\x03\x06\x05\x0e\n\x0b\n\x04\x05\0\x02\0\x12\x03\
|
||||||
\x01\x12\x03\x07\x04\x11\n\x0c\n\x05\x05\0\x02\0\x02\x12\x03\x07\x14\x15\
|
\x07\x04\x16\n\x0c\n\x05\x05\0\x02\0\x01\x12\x03\x07\x04\x11\n\x0c\n\x05\
|
||||||
\n\x0b\n\x04\x05\0\x02\x01\x12\x03\x08\x04\x18\n\x0c\n\x05\x05\0\x02\x01\
|
\x05\0\x02\0\x02\x12\x03\x07\x14\x15\n\x0b\n\x04\x05\0\x02\x01\x12\x03\
|
||||||
\x01\x12\x03\x08\x04\x13\n\x0c\n\x05\x05\0\x02\x01\x02\x12\x03\x08\x16\
|
\x08\x04\x18\n\x0c\n\x05\x05\0\x02\x01\x01\x12\x03\x08\x04\x13\n\x0c\n\
|
||||||
\x17\n\x0b\n\x04\x05\0\x02\x02\x12\x03\t\x04\x1b\n\x0c\n\x05\x05\0\x02\
|
\x05\x05\0\x02\x01\x02\x12\x03\x08\x16\x17\n\x0b\n\x04\x05\0\x02\x02\x12\
|
||||||
\x02\x01\x12\x03\t\x04\x16\n\x0c\n\x05\x05\0\x02\x02\x02\x12\x03\t\x19\
|
\x03\t\x04\x1b\n\x0c\n\x05\x05\0\x02\x02\x01\x12\x03\t\x04\x16\n\x0c\n\
|
||||||
\x1ab\x06proto3\
|
\x05\x05\0\x02\x02\x02\x12\x03\t\x19\x1a\n\x0b\n\x04\x05\0\x02\x03\x12\
|
||||||
|
\x03\n\x04\x15\n\x0c\n\x05\x05\0\x02\x03\x01\x12\x03\n\x04\x10\n\x0c\n\
|
||||||
|
\x05\x05\0\x02\x03\x02\x12\x03\n\x13\x14b\x06proto3\
|
||||||
";
|
";
|
||||||
|
|
||||||
static file_descriptor_proto_lazy: ::protobuf::rt::LazyV2<::protobuf::descriptor::FileDescriptorProto> = ::protobuf::rt::LazyV2::INIT;
|
static file_descriptor_proto_lazy: ::protobuf::rt::LazyV2<::protobuf::descriptor::FileDescriptorProto> = ::protobuf::rt::LazyV2::INIT;
|
||||||
|
|
|
@ -8,4 +8,5 @@ enum ErrorCode {
|
||||||
InternalError = 0;
|
InternalError = 0;
|
||||||
DuplicateSource = 1;
|
DuplicateSource = 1;
|
||||||
UnsupportedMessage = 2;
|
UnsupportedMessage = 2;
|
||||||
|
Unauthorized = 3;
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
|
||||||
use futures_core::{ready, Stream};
|
use futures_core::{ready, Stream};
|
||||||
|
|
||||||
use crate::connect::Retry;
|
use crate::connect::Retry;
|
||||||
|
use bytes::Buf;
|
||||||
use futures_core::future::BoxFuture;
|
use futures_core::future::BoxFuture;
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use std::{
|
use std::{
|
||||||
|
@ -14,8 +15,15 @@ use std::{
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
use tokio::task::JoinHandle;
|
use tokio::{sync::RwLock, task::JoinHandle};
|
||||||
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
|
use tokio_tungstenite::{
|
||||||
|
tungstenite::{
|
||||||
|
protocol::{frame::coding::CloseCode, CloseFrame},
|
||||||
|
Message,
|
||||||
|
},
|
||||||
|
MaybeTlsStream,
|
||||||
|
WebSocketStream,
|
||||||
|
};
|
||||||
|
|
||||||
pub type MsgReceiver = UnboundedReceiver<Message>;
|
pub type MsgReceiver = UnboundedReceiver<Message>;
|
||||||
pub type MsgSender = UnboundedSender<Message>;
|
pub type MsgSender = UnboundedSender<Message>;
|
||||||
|
@ -24,22 +32,58 @@ pub trait WsMessageHandler: Sync + Send + 'static {
|
||||||
fn receive_message(&self, msg: WsMessage);
|
fn receive_message(&self, msg: WsMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NotifyCallback = Arc<dyn Fn(&WsState) + Send + Sync + 'static>;
|
||||||
|
struct WsStateNotify {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
state: WsState,
|
||||||
|
callback: Option<NotifyCallback>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsStateNotify {
|
||||||
|
fn update_state(&mut self, state: WsState) {
|
||||||
|
if let Some(f) = &self.callback {
|
||||||
|
f(&state);
|
||||||
|
}
|
||||||
|
self.state = state;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum WsState {
|
||||||
|
Init,
|
||||||
|
Connected(Arc<WsSender>),
|
||||||
|
Disconnected(WsError),
|
||||||
|
}
|
||||||
|
|
||||||
pub struct WsController {
|
pub struct WsController {
|
||||||
sender: Option<Arc<WsSender>>,
|
|
||||||
handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
|
handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
|
||||||
|
state_notify: Arc<RwLock<WsStateNotify>>,
|
||||||
addr: Option<String>,
|
addr: Option<String>,
|
||||||
|
sender: Option<Arc<WsSender>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WsController {
|
impl WsController {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
|
let state_notify = Arc::new(RwLock::new(WsStateNotify {
|
||||||
|
state: WsState::Init,
|
||||||
|
callback: None,
|
||||||
|
}));
|
||||||
|
|
||||||
let controller = Self {
|
let controller = Self {
|
||||||
sender: None,
|
|
||||||
handlers: HashMap::new(),
|
handlers: HashMap::new(),
|
||||||
|
state_notify,
|
||||||
addr: None,
|
addr: None,
|
||||||
|
sender: None,
|
||||||
};
|
};
|
||||||
controller
|
controller
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn state_callback<SC>(&self, callback: SC)
|
||||||
|
where
|
||||||
|
SC: Fn(&WsState) + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
(self.state_notify.write().await).callback = Some(Arc::new(callback));
|
||||||
|
}
|
||||||
|
|
||||||
pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
|
pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
|
||||||
let source = handler.source();
|
let source = handler.source();
|
||||||
if self.handlers.contains_key(&source) {
|
if self.handlers.contains_key(&source) {
|
||||||
|
@ -61,9 +105,12 @@ impl WsController {
|
||||||
fn _connect(&mut self, addr: String, retry: Option<BoxFuture<'static, ()>>) -> Result<JoinHandle<()>, ServerError> {
|
fn _connect(&mut self, addr: String, retry: Option<BoxFuture<'static, ()>>) -> Result<JoinHandle<()>, ServerError> {
|
||||||
log::debug!("🐴 ws connect: {}", &addr);
|
log::debug!("🐴 ws connect: {}", &addr);
|
||||||
let (connection, handlers) = self.make_connect(addr.clone());
|
let (connection, handlers) = self.make_connect(addr.clone());
|
||||||
|
let state_notify = self.state_notify.clone();
|
||||||
|
let sender = self.sender.clone().expect("Sender should be not empty after calling make_connect");
|
||||||
Ok(tokio::spawn(async move {
|
Ok(tokio::spawn(async move {
|
||||||
match connection.await {
|
match connection.await {
|
||||||
Ok(stream) => {
|
Ok(stream) => {
|
||||||
|
state_notify.write().await.update_state(WsState::Connected(sender));
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
result = stream => {
|
result = stream => {
|
||||||
match result {
|
match result {
|
||||||
|
@ -71,17 +118,19 @@ impl WsController {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// TODO: retry?
|
// TODO: retry?
|
||||||
log::error!("ws stream error {:?}", e);
|
log::error!("ws stream error {:?}", e);
|
||||||
|
state_notify.write().await.update_state(WsState::Disconnected(e));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
result = handlers => log::debug!("handlers completed {:?}", result),
|
result = handlers => log::debug!("handlers completed {:?}", result),
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
Err(e) => match retry {
|
Err(e) => {
|
||||||
None => log::error!("ws connect {} failed {:?}", addr, e),
|
log::error!("ws connect {} failed {:?}", addr, e);
|
||||||
Some(retry) => {
|
state_notify.write().await.update_state(WsState::Disconnected(e));
|
||||||
|
if let Some(retry) = retry {
|
||||||
tokio::spawn(retry);
|
tokio::spawn(retry);
|
||||||
},
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
@ -101,17 +150,10 @@ impl WsController {
|
||||||
let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
|
let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
|
||||||
let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
|
let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
|
||||||
let handlers = self.handlers.clone();
|
let handlers = self.handlers.clone();
|
||||||
self.sender = Some(Arc::new(WsSender::new(ws_tx)));
|
self.sender = Some(Arc::new(WsSender { ws_tx }));
|
||||||
self.addr = Some(addr.clone());
|
self.addr = Some(addr.clone());
|
||||||
(WsConnection::new(msg_tx, ws_rx, addr), WsHandlers::new(handlers, msg_rx))
|
(WsConnection::new(msg_tx, ws_rx, addr), WsHandlers::new(handlers, msg_rx))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
|
|
||||||
match self.sender.as_ref() {
|
|
||||||
None => Err(WsError::internal().context("Should call make_connect first")),
|
|
||||||
Some(sender) => sender.send(msg.into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pin_project]
|
#[pin_project]
|
||||||
|
@ -146,17 +188,55 @@ impl Future for WsHandlers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct WsSender {
|
// impl WsSender for WsController {
|
||||||
|
// fn send_msg(&self, msg: WsMessage) -> Result<(), WsError> {
|
||||||
|
// match self.ws_tx.as_ref() {
|
||||||
|
// None => Err(WsError::internal().context("Should call make_connect
|
||||||
|
// first")), Some(sender) => {
|
||||||
|
// let _ = sender.unbounded_send(msg.into()).map_err(|e|
|
||||||
|
// WsError::internal().context(e))?; Ok(())
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct WsSender {
|
||||||
ws_tx: MsgSender,
|
ws_tx: MsgSender,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WsSender {
|
impl WsSender {
|
||||||
pub fn new(ws_tx: MsgSender) -> Self { Self { ws_tx } }
|
pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
|
||||||
|
let msg = msg.into();
|
||||||
pub fn send(&self, msg: WsMessage) -> Result<(), WsError> {
|
|
||||||
let _ = self.ws_tx.unbounded_send(msg.into()).map_err(|e| WsError::internal().context(e))?;
|
let _ = self.ws_tx.unbounded_send(msg.into()).map_err(|e| WsError::internal().context(e))?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn send_text(&self, source: &str, text: &str) -> Result<(), WsError> {
|
||||||
|
let msg = WsMessage {
|
||||||
|
source: source.to_string(),
|
||||||
|
data: text.as_bytes().to_vec(),
|
||||||
|
};
|
||||||
|
self.send_msg(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_binary(&self, source: &str, bytes: Vec<u8>) -> Result<(), WsError> {
|
||||||
|
let msg = WsMessage {
|
||||||
|
source: source.to_string(),
|
||||||
|
data: bytes,
|
||||||
|
};
|
||||||
|
self.send_msg(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_disconnect(&self, reason: &str) -> Result<(), WsError> {
|
||||||
|
let frame = CloseFrame {
|
||||||
|
code: CloseCode::Normal,
|
||||||
|
reason: reason.to_owned().into(),
|
||||||
|
};
|
||||||
|
let msg = Message::Close(Some(frame));
|
||||||
|
let _ = self.ws_tx.unbounded_send(msg).map_err(|e| WsError::internal().context(e))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue