// rusty_feeder/limit/lockfree_rate_limiter.rs
use std::{
2 fmt,
3 num::NonZeroUsize,
4 sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
5 time::Duration,
6};
7
8use parking_lot::Mutex;
9use quanta::Clock;
10use smallvec::SmallVec;
11
/// A token-bucket rate limiter built on atomic counters.
///
/// Aligned to a cache line (64 bytes) to keep the hot atomic fields from
/// false-sharing with neighboring data.
#[repr(align(64))] pub struct LockFreeRateLimiter {
    /// Upper bound on tokens the bucket ever holds.
    max_tokens: usize,

    /// Tokens currently available; decremented via CAS on acquire.
    available_tokens: AtomicUsize,

    /// Length of the refill window in nanoseconds.
    /// NOTE(review): elapsed time is measured as `quanta::Clock::raw()`
    /// deltas and compared directly against this value — confirm raw ticks
    /// are nanoseconds on the target platform.
    window_ns: u64,

    /// Raw timestamp of the last refill accounting point.
    last_refill: AtomicU64,

    /// Single-refiller gate: set while one thread performs a refill.
    refilling: AtomicBool,

    /// Recent acquisition timestamps used for wait-time estimation.
    /// Best-effort only: all call sites use `try_lock` and skip on contention.
    recent_timestamps: Mutex<SmallVec<[u64; 64]>>,
}
37
/// Guard returned by acquire calls.
///
/// `acquired` records whether a token was actually taken; only then does
/// dropping the guard (or calling `release()`) return a token to the limiter.
#[derive(Debug)]
pub struct TokenGuard<'a> {
    // The limiter this guard's token (if any) came from.
    limiter: &'a LockFreeRateLimiter,
    // True while this guard still owns a token.
    acquired: bool,
}
44
45impl<'a> TokenGuard<'a> {
46 const fn new(limiter: &'a LockFreeRateLimiter, acquired: bool) -> Self {
47 Self { limiter, acquired }
48 }
49
50 #[inline]
52 pub const fn is_acquired(&self) -> bool {
53 self.acquired
54 }
55
56 #[inline]
58 pub fn release(mut self) {
59 if self.acquired {
60 self.limiter.release_token();
61 self.acquired = false;
62 }
63 }
64}
65
66impl Drop for TokenGuard<'_> {
67 fn drop(&mut self) {
68 if self.acquired {
69 self.limiter.release_token();
70 }
71 }
72}
73
74impl LockFreeRateLimiter {
75 #[must_use]
77 pub fn new(max_tokens: NonZeroUsize, window: Duration) -> Self {
78 static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
79 let clock = CLOCK.get_or_init(Clock::new);
80 let now = clock.raw();
81
82 Self {
83 max_tokens: max_tokens.get(),
84 available_tokens: AtomicUsize::new(max_tokens.get()),
85 window_ns: window.as_nanos() as u64,
86 last_refill: AtomicU64::new(now),
87 refilling: AtomicBool::new(false),
88 recent_timestamps: Mutex::new(SmallVec::new()),
89 }
90 }
91
92 #[inline(always)]
94 pub fn try_acquire(&self) -> TokenGuard<'_> {
95 let mut current = self.available_tokens.load(Ordering::Acquire);
97
98 while current > 0 {
99 match self.available_tokens.compare_exchange_weak(
100 current,
101 current - 1,
102 Ordering::AcqRel,
103 Ordering::Relaxed,
104 ) {
105 Ok(_) => {
106 static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
109 let clock = CLOCK.get_or_init(Clock::new);
110 let now = clock.raw();
111 if let Some(mut recent) = self.recent_timestamps.try_lock() {
112 recent.push(now);
113
114 if recent.len() > 128 {
116 recent.drain(0..64);
117 }
118 }
119
120 if current <= self.max_tokens / 4 {
122 self.try_refill_tokens();
123 }
124
125 return TokenGuard::new(self, true);
126 }
127 Err(actual) => {
128 current = actual;
130 }
131 }
132 }
133
134 self.try_refill_tokens();
136
137 current = self.available_tokens.load(Ordering::Acquire);
138 if current > 0
139 && self
140 .available_tokens
141 .compare_exchange_weak(current, current - 1, Ordering::AcqRel, Ordering::Relaxed)
142 .is_ok()
143 {
144 static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
145 let clock = CLOCK.get_or_init(Clock::new);
146 let now = clock.raw();
147 if let Some(mut recent) = self.recent_timestamps.try_lock() {
148 recent.push(now);
149 }
150 return TokenGuard::new(self, true);
151 }
152
153 TokenGuard::new(self, false)
155 }
156
157 #[inline]
159 pub fn acquire(&self) -> TokenGuard<'_> {
160 let guard = self.try_acquire();
162 if guard.is_acquired() {
163 return guard;
164 }
165
166 loop {
168 self.refill_tokens();
170
171 let guard = self.try_acquire();
173 if guard.is_acquired() {
174 return guard;
175 }
176
177 let wait_time = self.estimate_wait_time();
179
180 if wait_time > 0 {
182 std::thread::sleep(Duration::from_nanos(wait_time));
183 } else {
184 std::thread::yield_now();
186 }
187 }
188 }
189
190 #[inline]
192 fn release_token(&self) {
193 self.available_tokens.fetch_add(1, Ordering::Release);
195 }
196
197 #[inline]
199 fn try_refill_tokens(&self) {
200 if self
202 .refilling
203 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
204 .is_err()
205 {
206 return; }
208
209 static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
211 let clock = CLOCK.get_or_init(Clock::new);
212 let now = clock.raw();
213 let last = self.last_refill.load(Ordering::Relaxed);
214
215 let elapsed_ns = now.saturating_sub(last);
217 if elapsed_ns >= self.window_ns / 10 {
218 let refill_amount =
221 ((elapsed_ns as f64 / self.window_ns as f64) * self.max_tokens as f64) as usize;
222
223 if refill_amount > 0 {
224 let current = self.available_tokens.load(Ordering::Relaxed);
226 let new_count = std::cmp::min(current + refill_amount, self.max_tokens);
227
228 if new_count > current {
229 self.available_tokens.store(new_count, Ordering::Release);
230
231 let new_last =
233 last + (elapsed_ns * refill_amount as u64 / self.max_tokens as u64);
234 self.last_refill.store(new_last, Ordering::Release);
235 }
236 }
237 }
238
239 self.refilling.store(false, Ordering::Release);
241 }
242
243 #[inline]
245 fn refill_tokens(&self) {
246 if self
249 .refilling
250 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
251 .is_err()
252 {
253 return; }
255
256 static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
258 let clock = CLOCK.get_or_init(Clock::new);
259 let now = clock.raw();
260 let last = self.last_refill.load(Ordering::Relaxed);
261
262 let elapsed_ns = now.saturating_sub(last);
264
265 if elapsed_ns >= self.window_ns {
267 self.available_tokens
268 .store(self.max_tokens, Ordering::Release);
269 self.last_refill.store(now, Ordering::Release);
270
271 if let Some(mut recent) = self.recent_timestamps.try_lock() {
273 recent.clear();
274 }
275 } else if elapsed_ns > 0 {
276 let refill_fraction = elapsed_ns as f64 / self.window_ns as f64;
278 let refill_amount = (refill_fraction * self.max_tokens as f64) as usize;
279
280 if refill_amount > 0 {
281 let current = self.available_tokens.load(Ordering::Relaxed);
283 let new_count = std::cmp::min(current + refill_amount, self.max_tokens);
284
285 if new_count > current {
286 self.available_tokens.store(new_count, Ordering::Release);
287
288 let new_last =
290 last + (elapsed_ns * refill_amount as u64 / self.max_tokens as u64);
291 self.last_refill.store(new_last, Ordering::Release);
292
293 if let Some(mut recent) = self.recent_timestamps.try_lock() {
295 let cutoff = now.saturating_sub(self.window_ns);
296 while let Some(&first) = recent.first() {
297 if first < cutoff {
298 recent.remove(0);
299 } else {
300 break;
301 }
302 }
303 }
304 }
305 }
306 }
307
308 self.refilling.store(false, Ordering::Release);
310 }
311
312 #[inline]
314 fn estimate_wait_time(&self) -> u64 {
315 static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
316 let clock = CLOCK.get_or_init(Clock::new);
317 let now = clock.raw();
318
319 if let Some(recent_timestamps) = self.recent_timestamps.try_lock()
321 && let Some(&oldest) = recent_timestamps.first()
322 {
323 let token_expiry = oldest + self.window_ns;
325 if token_expiry > now {
326 return token_expiry - now;
327 }
328 }
329
330 let last_refill = self.last_refill.load(Ordering::Relaxed);
332 let next_token_time = last_refill + (self.window_ns / self.max_tokens as u64);
333
334 next_token_time.saturating_sub(now)
335 }
336
337 #[inline]
339 pub const fn max_tokens(&self) -> usize {
340 self.max_tokens
341 }
342
343 #[inline]
345 pub fn available_tokens(&self) -> usize {
346 self.available_tokens.load(Ordering::Relaxed)
347 }
348
349 #[inline]
351 pub const fn window(&self) -> Duration {
352 Duration::from_nanos(self.window_ns)
353 }
354}
355
impl Drop for LockFreeRateLimiter {
    fn drop(&mut self) {
        // Intentionally empty: every field (atomics, mutex, SmallVec) cleans
        // up via its own Drop impl.
        // NOTE(review): an empty Drop impl still affects drop-check and
        // forbids partial moves out of the struct — consider deleting the
        // whole impl if nothing is planned here.
    }
}
361
362impl Clone for LockFreeRateLimiter {
363 fn clone(&self) -> Self {
364 Self {
367 max_tokens: self.max_tokens,
368 available_tokens: AtomicUsize::new(self.available_tokens.load(Ordering::Relaxed)),
369 window_ns: self.window_ns,
370 last_refill: AtomicU64::new(self.last_refill.load(Ordering::Relaxed)),
371 refilling: AtomicBool::new(false),
372 recent_timestamps: Mutex::new(SmallVec::new()),
373 }
374 }
375}
376
377impl fmt::Debug for LockFreeRateLimiter {
378 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
379 f.debug_struct("LockFreeRateLimiter")
380 .field("max_tokens", &self.max_tokens)
381 .field(
382 "available_tokens",
383 &self.available_tokens.load(Ordering::Relaxed),
384 )
385 .field("window_ns", &self.window_ns)
386 .field("last_refill", &self.last_refill.load(Ordering::Relaxed))
387 .field("refilling", &self.refilling.load(Ordering::Relaxed))
388 .finish()
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Tokens can be acquired up to the configured maximum, are refused
    /// once the bucket is drained, and come back after a full window.
    #[test]
    fn test_basic_rate_limiting() {
        // Separate clock static mirroring the ones inside the limiter;
        // raw timestamps from it are written straight into `last_refill`.
        static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
        let clock = CLOCK.get_or_init(Clock::new);

        let limiter =
            LockFreeRateLimiter::new(NonZeroUsize::new(5).unwrap(), Duration::from_millis(500));

        // Pin the refill marker to "now" so elapsed time starts at zero.
        let now = clock.raw();
        limiter.last_refill.store(now, Ordering::SeqCst);

        // NOTE: each guard is dropped at the end of its iteration, which
        // returns the token, so every acquisition sees a non-empty bucket.
        for i in 0..5 {
            let token = limiter.try_acquire();
            assert!(token.is_acquired(), "Failed to acquire token {i}");
        }

        // Force-drain the bucket to test the refusal path deterministically.
        limiter.available_tokens.store(0, Ordering::SeqCst);

        let token = limiter.try_acquire();
        assert!(
            !token.is_acquired(),
            "Token should not be acquired when available_tokens is 0"
        );

        // Sleep past the 500 ms window, then force a refill and expect
        // acquisition to succeed again.
        thread::sleep(Duration::from_millis(600));

        limiter.refill_tokens();

        let token = limiter.try_acquire();
        assert!(token.is_acquired(), "Failed to acquire token after refill");
    }

    /// Explicit `release()` immediately returns the token to the bucket.
    #[test]
    fn test_token_guard_release() {
        let limiter =
            LockFreeRateLimiter::new(NonZeroUsize::new(1).unwrap(), Duration::from_millis(100));

        limiter.available_tokens.store(1, Ordering::SeqCst);

        let token = limiter.try_acquire();
        assert!(token.is_acquired());

        // Single-token bucket: a second acquire must fail while the first
        // guard is still alive.
        let token2 = limiter.try_acquire();
        assert!(!token2.is_acquired());

        token.release();

        let token3 = limiter.try_acquire();
        assert!(token3.is_acquired());
    }

    /// Ten threads race for tokens. Since dropped guards return tokens and
    /// refills may credit more, the total success count is only bounded to
    /// a broad range rather than an exact value.
    #[test]
    fn test_concurrent_rate_limiting() {
        let limiter = Arc::new(LockFreeRateLimiter::new(
            NonZeroUsize::new(100).unwrap(),
            Duration::from_millis(1000),
        ));

        limiter.available_tokens.store(100, Ordering::SeqCst);

        // Barrier of 11 = 10 workers + the coordinating test thread, so all
        // workers start at once to maximize CAS contention.
        let barrier = Arc::new(std::sync::Barrier::new(11)); let threads: Vec<_> = (0..10)
            .map(|_| {
                let limiter_clone = Arc::clone(&limiter);
                let barrier_clone = Arc::clone(&barrier);
                thread::spawn(move || {
                    barrier_clone.wait();

                    let mut success_count = 0;
                    for _ in 0..20 {
                        let token = limiter_clone.try_acquire();
                        if token.is_acquired() {
                            success_count += 1;
                            // Yield so other threads interleave acquisitions.
                            thread::yield_now();
                        }
                    }
                    success_count
                })
            })
            .collect();

        barrier.wait();

        let total_success: usize = threads
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .sum();

        assert!(
            (100..=200).contains(&total_success),
            "Expected 100-200 successful acquisitions, got {total_success}"
        );
    }

    /// `acquire()` blocks until a helper thread hands back a token, and
    /// the measured wait is roughly the helper's 50 ms delay.
    #[test]
    fn test_blocking_acquire() {
        static CLOCK: std::sync::OnceLock<Clock> = std::sync::OnceLock::new();
        let clock = CLOCK.get_or_init(Clock::new);

        let limiter =
            LockFreeRateLimiter::new(NonZeroUsize::new(1).unwrap(), Duration::from_millis(100));

        // Start empty so acquire() must wait.
        limiter.available_tokens.store(0, Ordering::SeqCst);

        let now = clock.raw();
        limiter.last_refill.store(now, Ordering::SeqCst);

        let limiter_clone = Arc::new(limiter);
        let thread_limiter = Arc::clone(&limiter_clone);

        // Helper injects a token directly after ~50 ms.
        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            thread_limiter
                .available_tokens
                .fetch_add(1, Ordering::SeqCst);
        });

        let start = clock.now();

        let token = limiter_clone.acquire();
        assert!(
            token.is_acquired(),
            "Token should be acquired after waiting"
        );

        // 40 ms lower bound leaves slack for scheduler jitter; the limiter
        // may also satisfy the wait via its own refill slightly early.
        let elapsed = clock.now().duration_since(start);
        assert!(
            elapsed >= Duration::from_millis(40),
            "Expected to wait at least 40ms, but only waited {elapsed:?}"
        );

        handle.join().unwrap();
    }
}