//! Lock-free token-bucket rate limiter (rusty_feeder/limit/lockfree_rate_limiter.rs).
use std::{
    fmt,
    num::NonZeroUsize,
    sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
    time::Duration,
};

use parking_lot::Mutex;
use quanta::Clock;
use smallvec::SmallVec;

// Note: helpers imported but not used in this implementation
// use super::helpers::*;
14
/// Ultra-low-latency token-bucket rate limiter designed for HFT applications.
///
/// All hot-path state lives in atomics, so acquisition never takes a lock.
/// The only mutex guards an auxiliary timestamp history and is accessed
/// exclusively via `try_lock` (best effort), so it cannot stall the fast path.
#[repr(align(64))] // Cache-line alignment for better CPU cache efficiency
pub struct LockFreeRateLimiter {
    /// Maximum number of tokens the bucket can hold.
    max_tokens: usize,

    /// Currently available tokens; atomic for lock-free acquire/release.
    available_tokens: AtomicUsize,

    /// Length of the rate-limit window, in nanoseconds.
    window_ns: u64,

    /// Timestamp of the last refill.
    // NOTE(review): this is stored in `quanta::Clock::raw()` units (raw
    // counter ticks), not nanoseconds — any comparison with `window_ns`
    // needs a scaling step. Verify all call sites convert consistently.
    last_refill: AtomicU64,

    /// Refill gate: set while one thread performs a refill so others skip it
    /// instead of contending.
    refilling: AtomicBool,

    /// Recent acquisition timestamps used for wait-time estimation
    /// (best-effort history; only ever accessed via `try_lock`).
    recent_timestamps: Mutex<SmallVec<[u64; 64]>>,
}
37
/// RAII guard representing a (possibly failed) token acquisition.
///
/// When `acquired` is true, dropping the guard returns the token to the
/// limiter; `release` can be used to return it explicitly before then.
/// Check `is_acquired` to learn whether the acquisition succeeded.
#[derive(Debug)]
pub struct TokenGuard<'a> {
    /// Limiter the token (if any) is returned to.
    limiter: &'a LockFreeRateLimiter,
    /// Whether this guard actually holds a token.
    acquired: bool,
}
44
45impl<'a> TokenGuard<'a> {
46    const fn new(limiter: &'a LockFreeRateLimiter, acquired: bool) -> Self {
47        Self { limiter, acquired }
48    }
49
50    /// Check if a token was successfully acquired
51    #[inline]
52    pub const fn is_acquired(&self) -> bool {
53        self.acquired
54    }
55
56    /// Manually release the token before the guard is dropped
57    #[inline]
58    pub fn release(mut self) {
59        if self.acquired {
60            self.limiter.release_token();
61            self.acquired = false;
62        }
63    }
64}
65
impl Drop for TokenGuard<'_> {
    /// Returns the token to the limiter when the guard goes out of scope,
    /// unless it was already returned via `release` (which clears `acquired`).
    fn drop(&mut self) {
        if self.acquired {
            self.limiter.release_token();
        }
    }
}
73
74impl LockFreeRateLimiter {
75    /// Create a new lock-free rate limiter
76    #[must_use]
77    pub fn new(max_tokens: NonZeroUsize, window: Duration) -> Self {
78        static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
79        let clock = CLOCK.get_or_init(Clock::new);
80        let now = clock.raw();
81
82        Self {
83            max_tokens: max_tokens.get(),
84            available_tokens: AtomicUsize::new(max_tokens.get()),
85            window_ns: window.as_nanos() as u64,
86            last_refill: AtomicU64::new(now),
87            refilling: AtomicBool::new(false),
88            recent_timestamps: Mutex::new(SmallVec::new()),
89        }
90    }
91
92    /// Try to acquire a token without waiting (lock-free fast path)
93    #[inline(always)]
94    pub fn try_acquire(&self) -> TokenGuard<'_> {
95        // Fast path: try to decrement available tokens atomically
96        let mut current = self.available_tokens.load(Ordering::Acquire);
97
98        while current > 0 {
99            match self.available_tokens.compare_exchange_weak(
100                current,
101                current - 1,
102                Ordering::AcqRel,
103                Ordering::Relaxed,
104            ) {
105                Ok(_) => {
106                    // Token acquired - record timestamp for adaptive rate limiting
107                    // Get a singleton clock for efficiency
108                    static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
109                    let clock = CLOCK.get_or_init(Clock::new);
110                    let now = clock.raw();
111                    if let Some(mut recent) = self.recent_timestamps.try_lock() {
112                        recent.push(now);
113
114                        // Keep the buffer at a reasonable size
115                        if recent.len() > 128 {
116                            recent.drain(0..64);
117                        }
118                    }
119
120                    // Opportunistically try to refill if we're close to running out of tokens
121                    if current <= self.max_tokens / 4 {
122                        self.try_refill_tokens();
123                    }
124
125                    return TokenGuard::new(self, true);
126                }
127                Err(actual) => {
128                    // Update our view of the current count and try again
129                    current = actual;
130                }
131            }
132        }
133
134        // No tokens available - try to refill and try once more
135        self.try_refill_tokens();
136
137        current = self.available_tokens.load(Ordering::Acquire);
138        if current > 0
139            && self
140                .available_tokens
141                .compare_exchange_weak(current, current - 1, Ordering::AcqRel, Ordering::Relaxed)
142                .is_ok()
143        {
144            static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
145            let clock = CLOCK.get_or_init(Clock::new);
146            let now = clock.raw();
147            if let Some(mut recent) = self.recent_timestamps.try_lock() {
148                recent.push(now);
149            }
150            return TokenGuard::new(self, true);
151        }
152
153        // Failed to acquire token
154        TokenGuard::new(self, false)
155    }
156
157    /// Wait until a token becomes available (may block briefly), returns RAII guard
158    #[inline]
159    pub fn acquire(&self) -> TokenGuard<'_> {
160        // First try the fast path
161        let guard = self.try_acquire();
162        if guard.is_acquired() {
163            return guard;
164        }
165
166        // Slow path: wait for tokens to become available
167        loop {
168            // Try to refill tokens
169            self.refill_tokens();
170
171            // Try to acquire a token
172            let guard = self.try_acquire();
173            if guard.is_acquired() {
174                return guard;
175            }
176
177            // Calculate wait time until next token
178            let wait_time = self.estimate_wait_time();
179
180            // Short sleep before retrying
181            if wait_time > 0 {
182                std::thread::sleep(Duration::from_nanos(wait_time));
183            } else {
184                // Yield to other threads if no specific wait time
185                std::thread::yield_now();
186            }
187        }
188    }
189
190    /// Release a token back to the pool (used internally by TokenGuard)
191    #[inline]
192    fn release_token(&self) {
193        // No need to track returns - we'll just increment the counter
194        self.available_tokens.fetch_add(1, Ordering::Release);
195    }
196
197    /// Try to refill tokens without blocking if too much time has elapsed
198    #[inline]
199    fn try_refill_tokens(&self) {
200        // Only one thread should perform refill at a time
201        if self
202            .refilling
203            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
204            .is_err()
205        {
206            return; // Another thread is already refilling
207        }
208
209        // Get the current timestamp
210        static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
211        let clock = CLOCK.get_or_init(Clock::new);
212        let now = clock.raw();
213        let last = self.last_refill.load(Ordering::Relaxed);
214
215        // Check if enough time has elapsed
216        let elapsed_ns = now.saturating_sub(last);
217        if elapsed_ns >= self.window_ns / 10 {
218            // At least 10% of the window has passed
219            // Calculate tokens to add
220            let refill_amount =
221                ((elapsed_ns as f64 / self.window_ns as f64) * self.max_tokens as f64) as usize;
222
223            if refill_amount > 0 {
224                // Add tokens up to max
225                let current = self.available_tokens.load(Ordering::Relaxed);
226                let new_count = std::cmp::min(current + refill_amount, self.max_tokens);
227
228                if new_count > current {
229                    self.available_tokens.store(new_count, Ordering::Release);
230
231                    // Update last refill time proportionally to tokens added
232                    let new_last =
233                        last + (elapsed_ns * refill_amount as u64 / self.max_tokens as u64);
234                    self.last_refill.store(new_last, Ordering::Release);
235                }
236            }
237        }
238
239        // Release the refill lock
240        self.refilling.store(false, Ordering::Release);
241    }
242
243    /// Force a refill of tokens based on elapsed time
244    #[inline]
245    fn refill_tokens(&self) {
246        // Similar to try_refill but will always attempt to refill tokens
247        // Only one thread should perform refill at a time
248        if self
249            .refilling
250            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
251            .is_err()
252        {
253            return; // Another thread is already refilling
254        }
255
256        // Get the current timestamp
257        static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
258        let clock = CLOCK.get_or_init(Clock::new);
259        let now = clock.raw();
260        let last = self.last_refill.load(Ordering::Relaxed);
261
262        // Calculate elapsed time
263        let elapsed_ns = now.saturating_sub(last);
264
265        // If full window has elapsed, do a complete refill
266        if elapsed_ns >= self.window_ns {
267            self.available_tokens
268                .store(self.max_tokens, Ordering::Release);
269            self.last_refill.store(now, Ordering::Release);
270
271            // Clear timestamps when doing a full refill
272            if let Some(mut recent) = self.recent_timestamps.try_lock() {
273                recent.clear();
274            }
275        } else if elapsed_ns > 0 {
276            // Partial refill based on elapsed time
277            let refill_fraction = elapsed_ns as f64 / self.window_ns as f64;
278            let refill_amount = (refill_fraction * self.max_tokens as f64) as usize;
279
280            if refill_amount > 0 {
281                // Add tokens up to max
282                let current = self.available_tokens.load(Ordering::Relaxed);
283                let new_count = std::cmp::min(current + refill_amount, self.max_tokens);
284
285                if new_count > current {
286                    self.available_tokens.store(new_count, Ordering::Release);
287
288                    // Update last refill time proportionally
289                    let new_last =
290                        last + (elapsed_ns * refill_amount as u64 / self.max_tokens as u64);
291                    self.last_refill.store(new_last, Ordering::Release);
292
293                    // Prune old timestamps
294                    if let Some(mut recent) = self.recent_timestamps.try_lock() {
295                        let cutoff = now.saturating_sub(self.window_ns);
296                        while let Some(&first) = recent.first() {
297                            if first < cutoff {
298                                recent.remove(0);
299                            } else {
300                                break;
301                            }
302                        }
303                    }
304                }
305            }
306        }
307
308        // Release the refill lock
309        self.refilling.store(false, Ordering::Release);
310    }
311
312    /// Estimate wait time until next token becomes available
313    #[inline]
314    fn estimate_wait_time(&self) -> u64 {
315        static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
316        let clock = CLOCK.get_or_init(Clock::new);
317        let now = clock.raw();
318
319        // If we have timestamps, use the oldest one to estimate when it will expire
320        if let Some(recent_timestamps) = self.recent_timestamps.try_lock()
321            && let Some(&oldest) = recent_timestamps.first()
322        {
323            // Calculate when the oldest timestamp leaves the window
324            let token_expiry = oldest + self.window_ns;
325            if token_expiry > now {
326                return token_expiry - now;
327            }
328        }
329
330        // If we can't determine from timestamps, estimate based on refill rate
331        let last_refill = self.last_refill.load(Ordering::Relaxed);
332        let next_token_time = last_refill + (self.window_ns / self.max_tokens as u64);
333
334        next_token_time.saturating_sub(now)
335    }
336
337    /// Get the maximum tokens
338    #[inline]
339    pub const fn max_tokens(&self) -> usize {
340        self.max_tokens
341    }
342
343    /// Get the current available tokens
344    #[inline]
345    pub fn available_tokens(&self) -> usize {
346        self.available_tokens.load(Ordering::Relaxed)
347    }
348
349    /// Get the window duration
350    #[inline]
351    pub const fn window(&self) -> Duration {
352        Duration::from_nanos(self.window_ns)
353    }
354}
355
// NOTE(review): this `Drop` impl is an observable no-op and could simply be
// removed — atomics and the parking_lot mutex clean themselves up. Kept here
// only to avoid touching the type's trait surface in this pass.
impl Drop for LockFreeRateLimiter {
    fn drop(&mut self) {
        // Nothing to clean up - we use atomic primitives
    }
}
361
impl Clone for LockFreeRateLimiter {
    /// Creates an INDEPENDENT limiter with the same configuration.
    ///
    /// The previous comment claimed the clone references the same limiter;
    /// that is not what this code does. The clone snapshots the current token
    /// count and `last_refill`, starts with an empty timestamp history and a
    /// clear refill gate, and its atomics evolve separately from the original
    /// afterwards. To share ONE limiter across threads, wrap it in an `Arc`
    /// instead of cloning.
    fn clone(&self) -> Self {
        Self {
            max_tokens: self.max_tokens,
            available_tokens: AtomicUsize::new(self.available_tokens.load(Ordering::Relaxed)),
            window_ns: self.window_ns,
            last_refill: AtomicU64::new(self.last_refill.load(Ordering::Relaxed)),
            refilling: AtomicBool::new(false),
            recent_timestamps: Mutex::new(SmallVec::new()),
        }
    }
}
376
377impl fmt::Debug for LockFreeRateLimiter {
378    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
379        f.debug_struct("LockFreeRateLimiter")
380            .field("max_tokens", &self.max_tokens)
381            .field(
382                "available_tokens",
383                &self.available_tokens.load(Ordering::Relaxed),
384            )
385            .field("window_ns", &self.window_ns)
386            .field("last_refill", &self.last_refill.load(Ordering::Relaxed))
387            .field("refilling", &self.refilling.load(Ordering::Relaxed))
388            .finish()
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Single-threaded acquire / exhaust / refill cycle.
    #[test]
    fn test_basic_rate_limiting() {
        // A static clock so timestamps written into the limiter are
        // consistent across the test body.
        static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
        let clock = CLOCK.get_or_init(Clock::new);

        // Small window and few tokens to keep the test fast.
        let limiter =
            LockFreeRateLimiter::new(NonZeroUsize::new(5).unwrap(), Duration::from_millis(500));

        // Reset `last_refill` so the test starts from a known instant.
        let now = clock.raw();
        limiter.last_refill.store(now, Ordering::SeqCst);

        // All 5 tokens should be acquirable immediately.
        // NOTE(review): each guard is dropped at the end of its iteration,
        // which returns the token, so this loop would succeed for any bucket
        // size — the explicit store below re-establishes a known state.
        for i in 0..5 {
            let token = limiter.try_acquire();
            assert!(token.is_acquired(), "Failed to acquire token {i}");
        }

        // Force an empty bucket so the failure path is deterministic.
        limiter.available_tokens.store(0, Ordering::SeqCst);

        // With zero tokens, acquisition must fail.
        let token = limiter.try_acquire();
        assert!(
            !token.is_acquired(),
            "Token should not be acquired when available_tokens is 0"
        );

        // Sleep past the full window (600ms > 500ms) to guarantee a refill.
        thread::sleep(Duration::from_millis(600));

        // Force the refill rather than relying on an acquire to trigger it.
        limiter.refill_tokens();

        // Tokens should be available again after the refill.
        let token = limiter.try_acquire();
        assert!(token.is_acquired(), "Failed to acquire token after refill");
    }

    /// `TokenGuard::release` returns the token ahead of the drop.
    #[test]
    fn test_token_guard_release() {
        let limiter =
            LockFreeRateLimiter::new(NonZeroUsize::new(1).unwrap(), Duration::from_millis(100));

        // Pin the available count so the test is deterministic.
        limiter.available_tokens.store(1, Ordering::SeqCst);

        // Acquire the only token.
        let token = limiter.try_acquire();
        assert!(token.is_acquired());

        // A second attempt must fail while the first guard is live.
        let token2 = limiter.try_acquire();
        assert!(!token2.is_acquired());

        // Explicitly return the first token.
        token.release();

        // The returned token should be acquirable again.
        let token3 = limiter.try_acquire();
        assert!(token3.is_acquired());
    }

    /// Many threads hammering `try_acquire` concurrently; total successes
    /// are bounded by initial tokens plus tokens refunded by dropped guards.
    #[test]
    fn test_concurrent_rate_limiting() {
        // Larger pool and window so refills don't dominate the outcome.
        let limiter = Arc::new(LockFreeRateLimiter::new(
            NonZeroUsize::new(100).unwrap(),
            Duration::from_millis(1000),
        ));

        // Start from a known full bucket.
        limiter.available_tokens.store(100, Ordering::SeqCst);

        // Barrier releases all workers at (approximately) the same moment.
        let barrier = Arc::new(std::sync::Barrier::new(11)); // 10 threads + main thread

        let threads: Vec<_> = (0..10)
            .map(|_| {
                let limiter_clone = Arc::clone(&limiter);
                let barrier_clone = Arc::clone(&barrier);
                thread::spawn(move || {
                    // Wait for all threads to be ready.
                    barrier_clone.wait();

                    let mut success_count = 0;
                    for _ in 0..20 {
                        let token = limiter_clone.try_acquire();
                        if token.is_acquired() {
                            success_count += 1;
                            // Hold the token only briefly (yield, not sleep)
                            // so the test stays fast.
                            thread::yield_now();
                        }
                    }
                    success_count
                })
            })
            .collect();

        // Release the workers.
        barrier.wait();

        let total_success: usize = threads
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .sum();

        // At least the initial 100 tokens must have been consumed; guard
        // drops refund tokens mid-test, so more than 100 (up to the 200
        // attempts) is possible.
        assert!(
            (100..=200).contains(&total_success),
            "Expected 100-200 successful acquisitions, got {total_success}"
        );
    }

    /// `acquire` blocks until another thread makes a token available.
    #[test]
    fn test_blocking_acquire() {
        // Static clock for measuring the blocking duration.
        static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
        let clock = CLOCK.get_or_init(Clock::new);

        // Window long enough that natural refill won't race the helper thread.
        let limiter =
            LockFreeRateLimiter::new(NonZeroUsize::new(1).unwrap(), Duration::from_millis(100));

        // Empty the bucket so `acquire` must wait.
        limiter.available_tokens.store(0, Ordering::SeqCst);

        // Anchor `last_refill` at "now" so no back-refill is owed.
        let now = clock.raw();
        limiter.last_refill.store(now, Ordering::SeqCst);

        // Helper thread injects a token after a 50ms delay.
        let limiter_clone = Arc::new(limiter);
        let thread_limiter = Arc::clone(&limiter_clone);

        let handle = thread::spawn(move || {
            // Simulated delay before the token appears.
            thread::sleep(Duration::from_millis(50));
            // Add a token directly to the counter.
            thread_limiter
                .available_tokens
                .fetch_add(1, Ordering::SeqCst);
        });

        // Start timing the blocking call.
        let start = clock.now();

        // Must block until the helper thread's token arrives.
        let token = limiter_clone.acquire();
        assert!(
            token.is_acquired(),
            "Token should be acquired after waiting"
        );

        // 40ms leaves slack under the 50ms injection delay for scheduler jitter.
        let elapsed = clock.now().duration_since(start);
        assert!(
            elapsed >= Duration::from_millis(40),
            "Expected to wait at least 40ms, but only waited {elapsed:?}"
        );

        // Ensure the helper thread completes.
        handle.join().unwrap();
    }
}