diff --git a/src/websocket.rs b/src/websocket.rs index 73fe102..c5ee839 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -5,12 +5,15 @@ use crate::{ ResultType, }; use bytes::{BufMut, Bytes, BytesMut}; +use futures::stream::SplitSink; use futures::{SinkExt, StreamExt}; +use std::sync::Arc; use std::{ io::{Error, ErrorKind}, net::SocketAddr, time::Duration, }; +use tokio::sync::Mutex; use tokio::{net::TcpStream, time::timeout}; use tokio_tungstenite::{ connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, @@ -22,7 +25,9 @@ use tungstenite::protocol::Role; pub struct Encrypt(Key, u64, u64); pub struct WsFramedStream { - stream: WebSocketStream>, + // stream: WebSocketStream>, + writer: Arc>, WsMessage>>>, + reader: futures::stream::SplitStream>>, addr: SocketAddr, encrypt: Option, send_timeout: u64, @@ -30,6 +35,9 @@ pub struct WsFramedStream { } impl WsFramedStream { + const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); + const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(15); + pub async fn new>( url: T, local_addr: Option, @@ -63,16 +71,22 @@ impl WsFramedStream { _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), }; - Ok(Self { - stream: ws_stream, + let (writer, reader) = ws_stream.split(); + + let mut ws = Self { + writer: Arc::new(Mutex::new(writer)), + reader, addr, encrypt: None, send_timeout: ms_timeout, - }) + }; + + ws.start_heartbeat(); + Ok(ws) } else { log::info!("{:?}", url_str); - let mut request = url_str + let request = url_str .into_client_request() .map_err(|e| Error::new(ErrorKind::Other, e))?; @@ -90,15 +104,36 @@ impl WsFramedStream { _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), }; - Ok(Self { - stream, + let (writer, reader) = stream.split(); + let mut ws = Self { + writer: Arc::new(Mutex::new(writer)), + reader, addr, encrypt: None, send_timeout: ms_timeout, - }) + }; + + ws.start_heartbeat(); + Ok(ws) } } + fn start_heartbeat(&self) { + let writer = Arc::clone(&self.writer); + tokio::spawn(async move { + let mut interval = tokio::time::interval(Self::HEARTBEAT_INTERVAL); + loop { + interval.tick().await; + let mut lock = writer.lock().await; + if let Err(e) = lock.send(WsMessage::Ping(Bytes::new())).await { + log::error!("Failed to send ping: {}", e); + break; + } + drop(lock); // 及时释放锁 + } + }); + } + pub fn set_raw(&mut self) {} pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { @@ -106,12 +141,14 @@ impl WsFramedStream { WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; + let (writer, reader) = ws_stream.split(); + Ok(Self { - stream: ws_stream, + writer: Arc::new(Mutex::new(writer)), + reader, addr, encrypt: None, send_timeout: 0, - // read_buf: BytesMut::new(), }) } @@ -120,12 +157,14 @@ impl WsFramedStream { WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; + let (writer, reader) = ws_stream.split(); + Self { - stream: ws_stream, + writer: Arc::new(Mutex::new(writer)), + reader, addr, encrypt: None, send_timeout: 0, - // read_buf: BytesMut::new(), } } @@ -157,52 +196,84 @@ impl WsFramedStream { #[inline] pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { - let msg = WsMessage::Binary(Bytes::from(bytes)); + let msg = WsMessage::Binary(bytes); + let mut writer = self.writer.lock().await; if self.send_timeout > 0 { - let send_future = self.stream.send(msg); - timeout(Duration::from_millis(self.send_timeout), send_future) - .await - .map_err(|_| Error::new(ErrorKind::TimedOut, "Send timeout"))? - .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?; + timeout(Duration::from_millis(self.send_timeout), writer.send(msg)).await?? } else { - self.stream - .send(msg) - .await - .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?; - } + writer.send(msg).await? + }; Ok(()) } #[inline] pub async fn next(&mut self) -> Option> { - log::info!("test"); + log::debug!("Waiting for next message"); + let start = std::time::Instant::now(); + loop { - match self.stream.next().await? { - Ok(WsMessage::Binary(data)) => { - let mut bytes = BytesMut::from(&data[..]); - if let Some(key) = self.encrypt.as_mut() { - if let Err(e) = key.dec(&mut bytes) { - return Some(Err(e)); + match self.reader.next().await { + Some(Ok(msg)) => { + log::debug!("Received message: {:?}", &msg); + match msg { + WsMessage::Binary(data) => { + log::info!("Received binary data ({} bytes)", data.len()); + let mut bytes = BytesMut::from(&data[..]); + if let Some(key) = self.encrypt.as_mut() { + log::debug!("Decrypting data with seq: {}", key.2); + match key.dec(&mut bytes) { + Ok(_) => { + log::debug!("Decryption successful"); + return Some(Ok(bytes)); + } + Err(e) => { + log::error!("Decryption failed: {}", e); + return Some(Err(e)); + } + } + } + return Some(Ok(bytes)); + } + WsMessage::Ping(ping) => { + log::info!("Received ping ({} bytes)", ping.len()); + let mut writer = self.writer.lock().await; + if let Err(e) = writer.send(WsMessage::Pong(ping)).await { + log::error!("Failed to send pong: {}", e); + return Some(Err(Error::new( + ErrorKind::Other, + format!("Failed to send pong: {}", e), + ))); + } + log::debug!("Pong sent"); + } + WsMessage::Pong(_) => { + log::debug!("Received pong"); + } + WsMessage::Close(frame) => { + log::info!("Connection closed: {:?}", frame); + return None; + } + _ => { + log::warn!("Unhandled message :{}", &msg); } } - return Some(Ok(bytes)); } - Ok(WsMessage::Ping(ping)) => { - if let Err(e) = self.stream.send(WsMessage::Pong(ping)).await { - return Some(Err(Error::new( - ErrorKind::Other, - format!("Failed to send pong: {}", e), - ))); - } - continue; + Some(Err(e)) => { + log::error!("WebSocket error: {}", e); + return Some(Err(Error::new( + ErrorKind::Other, + format!("Failed to send pong: {}", e), + ))); } - Ok(WsMessage::Pong(_)) => { - log::debug!("Received pong"); - continue; + None => { + log::info!("Connection closed gracefully"); + return None; } - Ok(WsMessage::Close(_)) => return None, - Ok(_) => continue, - Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))), + } + + if start.elapsed() > Self::HEARTBEAT_TIMEOUT { + log::warn!("No message received within heartbeat timeout"); + return Some(Err(Error::new(ErrorKind::TimedOut, "Heartbeat timeout"))); } } }