diff --git a/src/websocket.rs b/src/websocket.rs index c5ee839..a40b719 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -14,6 +14,7 @@ use std::{ time::Duration, }; use tokio::sync::Mutex; +use tokio::time::Instant; use tokio::{net::TcpStream, time::timeout}; use tokio_tungstenite::{ connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, @@ -34,10 +35,9 @@ pub struct WsFramedStream { // read_buf: BytesMut, } +const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(3); +const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(10); impl WsFramedStream { - const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); - const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(15); - pub async fn new>( url: T, local_addr: Option, @@ -121,15 +121,26 @@ impl WsFramedStream { fn start_heartbeat(&self) { let writer = Arc::clone(&self.writer); tokio::spawn(async move { - let mut interval = tokio::time::interval(Self::HEARTBEAT_INTERVAL); + let mut last_pong = Instant::now(); + let mut interval = tokio::time::interval(HEARTBEAT_INTERVAL); + loop { - interval.tick().await; - let mut lock = writer.lock().await; - if let Err(e) = lock.send(WsMessage::Ping(Bytes::new())).await { - log::error!("Failed to send ping: {}", e); - break; + tokio::select! { + _ = interval.tick() => { + let mut lock = writer.lock().await; + if let Err(e) = lock.send(WsMessage::Ping(Bytes::new())).await { + log::error!("Heartbeat failed: {}", e); + break; + } + log::debug!("Sent ping"); + } + _ = tokio::time::sleep(HEARTBEAT_TIMEOUT) => { + if last_pong.elapsed() > HEARTBEAT_TIMEOUT { + log::error!("Heartbeat timeout"); + break; + } + } } - drop(lock); // 及时释放锁 } }); } @@ -271,7 +282,7 @@ impl WsFramedStream { } } - if start.elapsed() > Self::HEARTBEAT_TIMEOUT { + if start.elapsed() > HEARTBEAT_TIMEOUT { log::warn!("No message received within heartbeat timeout"); return Some(Err(Error::new(ErrorKind::TimedOut, "Heartbeat timeout"))); }