// rusty_common/websocket/heartbeat.rs

//! Heartbeat monitoring for WebSocket connections.
//!
//! This module provides a heartbeat mechanism that monitors the health of WebSocket
//! connections and marks a connection as failed once too many consecutive heartbeats are
//! missed, signalling that a reconnection is needed.
5
6use log::{debug, warn};
7use parking_lot::RwLock;
8use quanta::{Clock, Instant};
9use std::sync::Arc;
10use std::time::Duration;
11
12use super::client::ConnectionState;
13use super::stats::ConnectionStats;
14
/// Health of a monitored WebSocket connection, as reported by a heartbeat check.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum HeartbeatStatus {
    /// No heartbeats are currently missed.
    Healthy,
    /// At least one heartbeat was missed, but fewer than the configured maximum.
    Unhealthy,
    /// The configured maximum of consecutive missed heartbeats has been reached.
    Dead,
}
25
26/// Heartbeat monitor for WebSocket connections
27#[derive(Debug, Clone)]
28pub struct HeartbeatMonitor {
29    /// Last time a message was received
30    last_message_time: Arc<RwLock<Instant>>,
31    /// Last time a heartbeat check was performed
32    last_check_time: Arc<RwLock<Instant>>,
33    /// Heartbeat interval in milliseconds
34    heartbeat_interval_milliseconds: u64,
35    /// Heartbeat timeout in milliseconds
36    heartbeat_timeout_milliseconds: u64,
37    /// Maximum number of consecutive missed heartbeats before considering the connection dead
38    max_missed_heartbeats: u32,
39    /// Number of consecutive missed heartbeats
40    missed_heartbeats: Arc<RwLock<u32>>,
41    /// Connection status
42    connection_status: Arc<RwLock<ConnectionState>>,
43    /// Connection statistics
44    stats: Arc<RwLock<ConnectionStats>>,
45    /// High-precision clock
46    clock: Clock,
47}
48
49impl HeartbeatMonitor {
50    /// Create a new heartbeat monitor
51    #[must_use]
52    pub fn new(
53        heartbeat_interval_milliseconds: u64,
54        heartbeat_timeout_milliseconds: u64,
55        max_missed_heartbeats: u32,
56        connection_status: Arc<RwLock<ConnectionState>>,
57        stats: Arc<RwLock<ConnectionStats>>,
58    ) -> Self {
59        let clock = Clock::new();
60        let now = clock.now();
61        Self {
62            last_message_time: Arc::new(RwLock::new(now)),
63            last_check_time: Arc::new(RwLock::new(now)),
64            heartbeat_interval_milliseconds,
65            heartbeat_timeout_milliseconds,
66            max_missed_heartbeats,
67            missed_heartbeats: Arc::new(RwLock::new(0)),
68            connection_status,
69            stats,
70            clock,
71        }
72    }
73
74    /// Update the last message time
75    pub fn update_last_message_time(&self) {
76        *self.last_message_time.write() = self.clock.now();
77        *self.missed_heartbeats.write() = 0;
78    }
79
80    /// Check the heartbeat status
81    pub fn check_heartbeat(&self) -> HeartbeatStatus {
82        // If heartbeat is disabled, always return Healthy
83        if self.heartbeat_interval_milliseconds == 0 {
84            return HeartbeatStatus::Healthy;
85        }
86
87        let now = self.clock.now();
88        let last_message_time = *self.last_message_time.read();
89        let last_check_time = *self.last_check_time.read();
90        let elapsed_since_last_message = now.duration_since(last_message_time);
91        let elapsed_since_last_check = now.duration_since(last_check_time);
92        let interval = Duration::from_millis(self.heartbeat_interval_milliseconds);
93        let timeout = Duration::from_millis(self.heartbeat_timeout_milliseconds);
94
95        // Only check heartbeat if enough time has passed since the last check
96        if elapsed_since_last_check < interval {
97            // If we haven't reached the heartbeat interval yet, return the current status
98            let missed = *self.missed_heartbeats.read();
99            return if missed >= self.max_missed_heartbeats {
100                HeartbeatStatus::Dead
101            } else if missed > 0 {
102                HeartbeatStatus::Unhealthy
103            } else {
104                HeartbeatStatus::Healthy
105            };
106        }
107
108        // Update last check time
109        *self.last_check_time.write() = now;
110
111        // Check if we've exceeded the heartbeat timeout
112        if elapsed_since_last_message > timeout {
113            // Increment missed heartbeats
114            let mut missed_heartbeats = self.missed_heartbeats.write();
115            *missed_heartbeats += 1;
116
117            // Log the missed heartbeat
118            debug!(
119                "Missed heartbeat: {} of {} (timeout: {}ms, elapsed: {}ms)",
120                *missed_heartbeats,
121                self.max_missed_heartbeats,
122                self.heartbeat_timeout_milliseconds,
123                elapsed_since_last_message.as_millis()
124            );
125
126            // Check if we've missed too many heartbeats
127            if *missed_heartbeats >= self.max_missed_heartbeats {
128                warn!(
129                    "Connection dead: missed {} heartbeats (max: {})",
130                    *missed_heartbeats, self.max_missed_heartbeats
131                );
132                *self.connection_status.write() = ConnectionState::Error;
133                self.stats.write().errors += 1;
134                return HeartbeatStatus::Dead;
135            }
136
137            // Connection is unhealthy but not dead yet
138            return HeartbeatStatus::Unhealthy;
139        }
140
141        // Connection is healthy
142        *self.missed_heartbeats.write() = 0;
143        HeartbeatStatus::Healthy
144    }
145
146    /// Reset the heartbeat monitor
147    pub fn reset(&self) {
148        let now = self.clock.now();
149        *self.last_message_time.write() = now;
150        *self.last_check_time.write() = now;
151        *self.missed_heartbeats.write() = 0;
152    }
153
154    /// Get the number of consecutive missed heartbeats
155    pub fn missed_heartbeats(&self) -> u32 {
156        *self.missed_heartbeats.read()
157    }
158
159    /// Get the time elapsed since the last message
160    pub fn time_since_last_message(&self) -> Duration {
161        self.clock
162            .now()
163            .duration_since(*self.last_message_time.read())
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_heartbeat_monitor() {
        let connection_status = Arc::new(RwLock::new(ConnectionState::Connected));
        let stats = Arc::new(RwLock::new(ConnectionStats::default()));
        let errors_before = stats.read().errors;

        // Create a heartbeat monitor with a short timeout for testing
        let monitor = HeartbeatMonitor::new(
            100, // 100ms heartbeat interval
            200, // 200ms heartbeat timeout
            3,   // 3 missed heartbeats before connection is dead
            Arc::clone(&connection_status),
            Arc::clone(&stats),
        );

        // Initially, the connection should be healthy
        assert_eq!(monitor.check_heartbeat(), HeartbeatStatus::Healthy);

        // Update the last message time
        monitor.update_last_message_time();
        assert_eq!(monitor.check_heartbeat(), HeartbeatStatus::Healthy);

        // Wait for the heartbeat timeout to expire
        thread::sleep(Duration::from_millis(300));

        // After the timeout, the connection should be unhealthy
        assert_eq!(monitor.check_heartbeat(), HeartbeatStatus::Unhealthy);
        assert_eq!(monitor.missed_heartbeats(), 1);

        // Wait for more timeouts
        thread::sleep(Duration::from_millis(300));
        assert_eq!(monitor.check_heartbeat(), HeartbeatStatus::Unhealthy);
        assert_eq!(monitor.missed_heartbeats(), 2);

        thread::sleep(Duration::from_millis(300));

        // After 3 missed heartbeats, the connection should be dead and the error recorded
        assert_eq!(monitor.check_heartbeat(), HeartbeatStatus::Dead);
        assert_eq!(monitor.missed_heartbeats(), 3);
        assert_eq!(*connection_status.read(), ConnectionState::Error);
        assert_eq!(stats.read().errors, errors_before + 1);

        // Before the next check is due, the status is derived from the counter alone
        assert_eq!(monitor.check_heartbeat(), HeartbeatStatus::Dead);
        assert_eq!(monitor.missed_heartbeats(), 3);

        // Reset the monitor
        monitor.reset();
        assert_eq!(monitor.missed_heartbeats(), 0);
        assert_eq!(monitor.check_heartbeat(), HeartbeatStatus::Healthy);
    }

    #[test]
    fn test_disabled_heartbeat_is_always_healthy() {
        let connection_status = Arc::new(RwLock::new(ConnectionState::Connected));
        let stats = Arc::new(RwLock::new(ConnectionStats::default()));

        // A zero interval disables monitoring, even with a tiny timeout
        let monitor = HeartbeatMonitor::new(
            0,
            1,
            1,
            Arc::clone(&connection_status),
            Arc::clone(&stats),
        );

        thread::sleep(Duration::from_millis(10));
        assert_eq!(monitor.check_heartbeat(), HeartbeatStatus::Healthy);
        assert_eq!(monitor.missed_heartbeats(), 0);
        assert_eq!(*connection_status.read(), ConnectionState::Connected);
    }

    #[test]
    fn test_message_resets_missed_heartbeats() {
        let connection_status = Arc::new(RwLock::new(ConnectionState::Connected));
        let stats = Arc::new(RwLock::new(ConnectionStats::default()));

        let monitor = HeartbeatMonitor::new(
            100, // 100ms heartbeat interval
            50,  // 50ms heartbeat timeout
            3,
            Arc::clone(&connection_status),
            Arc::clone(&stats),
        );

        // One interval passes with no message: one missed heartbeat
        thread::sleep(Duration::from_millis(150));
        assert_eq!(monitor.check_heartbeat(), HeartbeatStatus::Unhealthy);
        assert_eq!(monitor.missed_heartbeats(), 1);

        // A received message clears the counter and the status recovers immediately
        monitor.update_last_message_time();
        assert_eq!(monitor.missed_heartbeats(), 0);
        assert_eq!(monitor.check_heartbeat(), HeartbeatStatus::Healthy);
        assert_eq!(*connection_status.read(), ConnectionState::Connected);
    }
}