diff --git a/rust-lib/flowy-ws/src/connect.rs b/rust-lib/flowy-ws/src/connect.rs index 6c388278e1..065c17a889 100644 --- a/rust-lib/flowy-ws/src/connect.rs +++ b/rust-lib/flowy-ws/src/connect.rs @@ -1,23 +1,16 @@ -use crate::{errors::WsError, MsgReceiver, MsgSender, WsMessage}; +use crate::{errors::WsError, MsgReceiver, MsgSender}; use flowy_net::errors::ServerError; -use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender}; -use futures_core::{future::BoxFuture, ready, Stream}; -use futures_util::{ - future, - future::{Either, Select}, - pin_mut, - FutureExt, - StreamExt, -}; +use futures_core::{future::BoxFuture, ready}; +use futures_util::{FutureExt, StreamExt, TryStreamExt}; use pin_project::pin_project; use std::{ - collections::HashMap, + fmt, future::Future, pin::Pin, sync::Arc, task::{Context, Poll}, }; -use tokio::{net::TcpStream, task::JoinHandle}; +use tokio::net::TcpStream; use tokio_tungstenite::{ connect_async, tungstenite::{handshake::client::Response, http::StatusCode, Error, Message}, @@ -63,124 +56,70 @@ impl Future for WsConnection { loop { return match ready!(self.as_mut().project().fut.poll(cx)) { Ok((stream, _)) => { - log::debug!("🐴 ws connect success"); + log::debug!("🐴 ws connect success: {:?}", error); let (msg_tx, ws_rx) = (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap()); Poll::Ready(Ok(WsStream::new(msg_tx, ws_rx, stream))) }, - Err(error) => Poll::Ready(Err(error_to_flowy_response(error))), + Err(error) => { + log::debug!("🐴 ws connect failed: {:?}", error); + Poll::Ready(Err(error_to_flowy_response(error))) + }, }; } } } +type Fut = BoxFuture<'static, Result<(), WsError>>; #[pin_project] pub struct WsStream { - msg_tx: MsgSender, #[pin] - fut: Option<(BoxFuture<'static, ()>, BoxFuture<'static, ()>)>, + inner: Option<(Fut, Fut)>, } impl WsStream { pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, stream: WebSocketStream>) -> Self { let (ws_write, ws_read) = stream.split(); - let to_ws = ws_rx.map(Ok).forward(ws_write); - let from_ws = ws_read.for_each(|message| async { - // handle_new_message(msg_tx.clone(), message) - }); - // pin_mut!(to_ws, from_ws); Self { - msg_tx, - fut: Some(( + inner: Some(( Box::pin(async move { - let _ = from_ws.await; + let _ = ws_read.for_each(|message| async { post_message(msg_tx.clone(), message) }).await; + Ok(()) }), Box::pin(async move { - let _ = to_ws.await; + let _ = ws_rx.map(Ok).forward(ws_write).await?; + Ok(()) }), )), } } } +impl fmt::Debug for WsStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("WsStream").finish() } +} + impl Future for WsStream { - type Output = (); + type Output = Result<(), WsError>; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let (mut a, mut b) = self.fut.take().unwrap(); - match a.poll_unpin(cx) { - Poll::Ready(x) => Poll::Ready(()), - Poll::Pending => match b.poll_unpin(cx) { - Poll::Ready(x) => Poll::Ready(()), - Poll::Pending => { - // self.fut = Some((a, b)); - Poll::Pending - }, + let (mut left, mut right) = self.inner.take().unwrap(); + match left.poll_unpin(cx) { + Poll::Ready(l) => Poll::Ready(l), + Poll::Pending => { + // + match right.poll_unpin(cx) { + Poll::Ready(r) => Poll::Ready(r), + Poll::Pending => { + self.inner = Some((left, right)); + Poll::Pending + }, + } }, } } } -// pub struct WsStream { -// msg_tx: Option, -// ws_rx: Option, -// stream: Option>>, -// } -// -// impl WsStream { -// pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, stream: -// WebSocketStream>) -> Self { Self { -// msg_tx: Some(msg_tx), -// ws_rx: Some(ws_rx), -// stream: Some(stream), -// } -// } -// -// pub fn start(mut self) -> JoinHandle<()> { -// let (msg_tx, ws_rx) = (self.msg_tx.take().unwrap(), -// self.ws_rx.take().unwrap()); let (ws_write, ws_read) = -// self.stream.take().unwrap().split(); tokio::spawn(async move { -// let to_ws = ws_rx.map(Ok).forward(ws_write); -// let from_ws = ws_read.for_each(|message| async { -// handle_new_message(msg_tx.clone(), message) }); pin_mut!(to_ws, -// from_ws); -// -// match future::select(to_ws, from_ws).await { -// Either::Left(_l) => { -// log::info!("ws left"); -// }, -// Either::Right(_r) => { -// log::info!("ws right"); -// }, -// } -// }) -// } -// } -// -// impl Future for WsStream { -// type Output = (); -// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> -// Poll { let (msg_tx, ws_rx) = -// (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap()); let -// (ws_write, ws_read) = self.stream.take().unwrap().split(); let to_ws -// = ws_rx.map(Ok).forward(ws_write); let from_ws = -// ws_read.for_each(|message| async { handle_new_message(msg_tx.clone(), -// message) }); pin_mut!(to_ws, from_ws); -// -// loop { -// match ready!(Pin::new(&mut future::select(to_ws, -// from_ws)).poll(cx)) { Either::Left(a) => { -// // -// return Poll::Ready(()); -// }, -// Either::Right(b) => { -// // -// return Poll::Ready(()); -// }, -// } -// } -// } -// } - -fn handle_new_message(tx: MsgSender, message: Result) { +fn post_message(tx: MsgSender, message: Result) { match message { Ok(Message::Binary(bytes)) => match tx.unbounded_send(Message::Binary(bytes)) { Ok(_) => {}, diff --git a/rust-lib/flowy-ws/src/errors.rs b/rust-lib/flowy-ws/src/errors.rs index 557d7ccbca..3fdf0b3b72 100644 --- a/rust-lib/flowy-ws/src/errors.rs +++ b/rust-lib/flowy-ws/src/errors.rs @@ -62,3 +62,7 @@ impl std::convert::From for WsError { impl std::convert::From> for WsError { fn from(error: TrySendError) -> Self { WsError::internal().context(error) } } + +impl std::convert::From for WsError { + fn from(error: tokio_tungstenite::tungstenite::Error) -> Self { WsError::internal().context(error) } +} diff --git a/rust-lib/flowy-ws/src/ws.rs b/rust-lib/flowy-ws/src/ws.rs index 2445789886..59edac514d 100644 --- a/rust-lib/flowy-ws/src/ws.rs +++ b/rust-lib/flowy-ws/src/ws.rs @@ -1,14 +1,8 @@ use crate::{connect::WsConnection, errors::WsError, WsMessage}; use flowy_net::errors::ServerError; use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender}; -use futures_core::{future::BoxFuture, ready, Stream}; -use futures_util::{ - future, - future::{Either, Select}, - pin_mut, - FutureExt, - StreamExt, -}; +use futures_core::{ready, Stream}; + use pin_project::pin_project; use std::{ collections::HashMap, @@ -17,13 +11,8 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use tokio::{net::TcpStream, task::JoinHandle}; -use tokio_tungstenite::{ - connect_async, - tungstenite::{handshake::client::Response, http::StatusCode, Error, Message}, - MaybeTlsStream, - WebSocketStream, -}; +use tokio::task::JoinHandle; +use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream}; pub type MsgReceiver = UnboundedReceiver; pub type MsgSender = UnboundedSender; @@ -56,7 +45,7 @@ impl WsController { } pub fn connect(&mut self, addr: String) -> Result, ServerError> { - log::debug!("🐴 Try to connect: {}", &addr); + log::debug!("🐴 ws connect: {}", &addr); let (connection, handlers) = self.make_connect(addr); Ok(tokio::spawn(async { tokio::select! {