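//! Object pools (`rusty_common/memory/object_pool.rs`): a single-threaded pool, a mutex-backed
//! pool that can be shared across threads, and a fixed-capacity inline pool.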
use parking_lot::Mutex;
use std::cell::RefCell;
use std::mem::MaybeUninit;
use std::rc::Rc;
use std::sync::Arc;
10
11pub struct LocalObjectPool<T, const N: usize = 64> {
13 pool: RefCell<Vec<T>>,
14 initializer: Box<dyn Fn() -> T>,
15}
16
17impl<T, const N: usize> LocalObjectPool<T, N> {
18 pub fn new(initializer: impl Fn() -> T + 'static) -> Self {
20 let mut pool = Vec::with_capacity(N);
21
22 for _ in 0..N {
24 pool.push(initializer());
25 }
26
27 Self {
28 pool: RefCell::new(pool),
29 initializer: Box::new(initializer),
30 }
31 }
32
33 #[inline(always)]
35 pub fn get(&self) -> PooledObject<T, N> {
36 let obj = self
37 .pool
38 .borrow_mut()
39 .pop()
40 .unwrap_or_else(|| (self.initializer)());
41 PooledObject {
42 inner: Some(obj),
43 pool: self as *const Self,
44 }
45 }
46
47 #[inline(always)]
49 fn return_object(&self, obj: T) {
50 let mut pool = self.pool.borrow_mut();
51 if pool.len() < N {
52 pool.push(obj);
53 }
54 }
56}
57
58pub struct ThreadSafeObjectPool<T: Send> {
60 pool: Arc<Mutex<Vec<T>>>,
61 capacity: usize,
62 initializer: Arc<dyn Fn() -> T + Send + Sync>,
63}
64
65impl<T: Send> ThreadSafeObjectPool<T> {
66 pub fn new(capacity: usize, initializer: impl Fn() -> T + Send + Sync + 'static) -> Self {
68 let mut pool = Vec::with_capacity(capacity);
69
70 for _ in 0..capacity {
72 pool.push(initializer());
73 }
74
75 Self {
76 pool: Arc::new(Mutex::new(pool)),
77 capacity,
78 initializer: Arc::new(initializer),
79 }
80 }
81
82 #[inline(always)]
84 pub fn get(&self) -> ThreadSafePooledObject<T> {
85 let obj = self
86 .pool
87 .lock()
88 .pop()
89 .unwrap_or_else(|| (self.initializer)());
90 ThreadSafePooledObject {
91 inner: Some(obj),
92 pool: Arc::clone(&self.pool),
93 capacity: self.capacity,
94 }
95 }
96}
97
98impl<T: Send> Clone for ThreadSafeObjectPool<T> {
99 fn clone(&self) -> Self {
100 Self {
101 pool: Arc::clone(&self.pool),
102 capacity: self.capacity,
103 initializer: Arc::clone(&self.initializer),
104 }
105 }
106}
107
108pub struct PooledObject<T, const N: usize = 64> {
110 inner: Option<T>,
111 pool: *const LocalObjectPool<T, N>,
112}
113
114impl<T, const N: usize> PooledObject<T, N> {
115 pub fn take(mut self) -> T {
117 self.inner.take().expect("Object already taken")
118 }
119}
120
121impl<T, const N: usize> std::ops::Deref for PooledObject<T, N> {
122 type Target = T;
123
124 fn deref(&self) -> &Self::Target {
125 self.inner.as_ref().expect("Object already taken")
126 }
127}
128
129impl<T, const N: usize> std::ops::DerefMut for PooledObject<T, N> {
130 fn deref_mut(&mut self) -> &mut Self::Target {
131 self.inner.as_mut().expect("Object already taken")
132 }
133}
134
135impl<T, const N: usize> Drop for PooledObject<T, N> {
136 fn drop(&mut self) {
137 if let Some(obj) = self.inner.take() {
138 unsafe {
139 (*self.pool).return_object(obj);
140 }
141 }
142 }
143}
144
145pub struct ThreadSafePooledObject<T: Send> {
147 inner: Option<T>,
148 pool: Arc<Mutex<Vec<T>>>,
149 capacity: usize,
150}
151
152impl<T: Send> ThreadSafePooledObject<T> {
153 pub fn take(mut self) -> T {
155 self.inner.take().expect("Object already taken")
156 }
157}
158
159impl<T: Send> std::ops::Deref for ThreadSafePooledObject<T> {
160 type Target = T;
161
162 fn deref(&self) -> &Self::Target {
163 self.inner.as_ref().expect("Object already taken")
164 }
165}
166
167impl<T: Send> std::ops::DerefMut for ThreadSafePooledObject<T> {
168 fn deref_mut(&mut self) -> &mut Self::Target {
169 self.inner.as_mut().expect("Object already taken")
170 }
171}
172
173impl<T: Send> Drop for ThreadSafePooledObject<T> {
174 fn drop(&mut self) {
175 if let Some(obj) = self.inner.take() {
176 let mut pool = self.pool.lock();
177 if pool.len() < self.capacity {
178 pool.push(obj);
179 }
180 }
182 }
183}
184
/// Fixed-capacity pool of `N` values stored inline and handed out as `&mut T`.
///
/// Slots are filled with `T::default()` by [`init`](Self::init); [`try_get`](Self::try_get)
/// runs `init` itself on first use if it was never called, so a pool from [`new`](Self::new)
/// or `Default` is always safe to use. Since a handed-out `&mut T` borrows the pool, slot
/// bookkeeping is logical: callers mark a slot free again with
/// [`return_by_index`](Self::return_by_index), using the index reported by
/// [`try_get_with_index`](Self::try_get_with_index). Returned values are not reset.
pub struct FixedPool<T, const N: usize> {
    /// Slot storage; every slot holds a live `T` exactly when `initialized` is `true`.
    storage: [MaybeUninit<T>; N],
    /// `available[i]` is `true` when slot `i` may be handed out.
    available: [bool; N],
    /// Slot at which the next search for a free slot starts.
    next_free: usize,
    /// Whether `storage` has been filled by `init`.
    initialized: bool,
}

impl<T: Default, const N: usize> Default for FixedPool<T, N> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Default, const N: usize> FixedPool<T, N> {
    /// Creates a pool with every slot marked available.
    ///
    /// Slots are filled lazily (see [`init`](Self::init)), which keeps this usable in `const`
    /// contexts.
    pub const fn new() -> Self {
        Self {
            // SAFETY: an array of `MaybeUninit` does not require initialization.
            storage: unsafe { MaybeUninit::uninit().assume_init() },
            available: [true; N],
            next_free: 0,
            initialized: false,
        }
    }

    /// Fills every slot with `T::default()`.
    ///
    /// Calling it again replaces each slot with a fresh default value and drops the previous
    /// one. Availability flags are left unchanged.
    pub fn init(&mut self) {
        let was_initialized = self.initialized;
        for slot in self.storage.iter_mut() {
            // Build the new value before touching the slot, so a panicking `T::default()`
            // leaves the pool in a consistent state.
            let previous = std::mem::replace(slot, MaybeUninit::new(T::default()));
            if was_initialized {
                // SAFETY: the pool was initialized, so `previous` held a live `T`; it has been
                // moved out of the slot and is dropped exactly once here.
                drop(unsafe { previous.assume_init() });
            }
        }
        self.initialized = true;
    }

    /// Claims a free slot, returning its index and a mutable reference to its value.
    ///
    /// The search starts at `next_free` (the slot after the last one handed out, or the last
    /// slot returned) and wraps around. Returns `None` when every slot is taken or `N == 0`.
    /// Runs [`init`](Self::init) first if it was never called.
    #[inline]
    pub fn try_get_with_index(&mut self) -> Option<(usize, &mut T)> {
        if !self.initialized {
            self.init();
        }
        let start = self.next_free;
        // For `N == 0` the range is empty, so the modulo below is never evaluated.
        let index = (0..N)
            .map(|offset| (start + offset) % N)
            .find(|&slot| self.available[slot])?;
        self.available[index] = false;
        self.next_free = (index + 1) % N;
        // SAFETY: the pool is initialized, so every slot holds a live `T`; the returned
        // reference borrows `self` mutably and therefore cannot alias another one.
        let value = unsafe { self.storage[index].assume_init_mut() };
        Some((index, value))
    }

    /// Claims a free slot and returns a mutable reference to its value.
    ///
    /// Same as [`try_get_with_index`](Self::try_get_with_index) without the index.
    #[inline]
    pub fn try_get(&mut self) -> Option<&mut T> {
        self.try_get_with_index().map(|(_, value)| value)
    }

    /// Marks slot `index` free again and makes it the next slot to be handed out.
    ///
    /// # Panics
    ///
    /// Panics if `index >= N`. In debug builds, also panics if the slot is already free.
    #[inline]
    pub fn return_by_index(&mut self, index: usize) {
        debug_assert!(index < N);
        debug_assert!(!self.available[index]);
        self.available[index] = true;
        self.next_free = index;
    }
}

impl<T, const N: usize> Drop for FixedPool<T, N> {
    /// Drops every slot's value if the pool was initialized.
    fn drop(&mut self) {
        if self.initialized {
            for slot in self.storage.iter_mut() {
                // SAFETY: `initialized` means every slot holds a live `T`, and the pool is
                // being destroyed, so no slot is read again.
                unsafe { slot.as_mut_ptr().drop_in_place() };
            }
        }
    }
}
251
252pub type LocalObjectPool64<T> = LocalObjectPool<T, 64>;
255pub type LocalObjectPool32<T> = LocalObjectPool<T, 32>;
257pub type LocalObjectPool128<T> = LocalObjectPool<T, 128>;
259
260pub type PooledObject64<T> = PooledObject<T, 64>;
262pub type PooledObject32<T> = PooledObject<T, 32>;
264pub type PooledObject128<T> = PooledObject<T, 128>;
266
267pub use LocalObjectPool64 as DefaultLocalObjectPool;
269pub use PooledObject64 as DefaultPooledObject;
270
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Default, Debug, PartialEq)]
    struct TestObject {
        value: i32,
    }

    #[test]
    fn test_local_pool() {
        let pool: LocalObjectPool<TestObject, 64> = LocalObjectPool::new(TestObject::default);

        let mut obj1 = pool.get();
        obj1.value = 42;

        let mut obj2 = pool.get();
        obj2.value = 43;

        // Pre-filled objects start out in their default state.
        let obj3 = pool.get();
        assert_eq!(obj3.value, 0);

        drop(obj1);
        drop(obj2);

        // Returned objects are reused as-is (not reset).
        let obj4 = pool.get();
        assert!(obj4.value == 42 || obj4.value == 43);
    }

    #[test]
    fn test_local_pool_keeps_at_most_n_idle() {
        let pool: LocalObjectPool<TestObject, 2> = LocalObjectPool::new(TestObject::default);

        // Two objects come from the pre-filled pool, the third from the initializer.
        let guards: Vec<_> = (0..3).map(|_| pool.get()).collect();
        assert_eq!(pool.pool.borrow().len(), 0);

        // Only `N` objects are retained; the extra one is dropped.
        drop(guards);
        assert_eq!(pool.pool.borrow().len(), 2);
    }

    #[test]
    fn test_thread_safe_pool() {
        let pool = ThreadSafeObjectPool::new(10, TestObject::default);

        let handles: Vec<_> = (0..5)
            .map(|i| {
                let pool = pool.clone();
                std::thread::spawn(move || {
                    let mut obj = pool.get();
                    obj.value = i;
                    std::thread::sleep(std::time::Duration::from_millis(10));
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        // Every guard was dropped inside its thread and the pool never ran dry,
        // so all 10 objects are idle again.
        assert_eq!(pool.pool.lock().len(), 10);
    }

    #[test]
    fn test_fixed_pool() {
        let mut pool: FixedPool<TestObject, 4> = FixedPool::new();
        pool.init();

        let obj1 = pool.try_get().unwrap();
        obj1.value = 1;

        let obj2 = pool.try_get().unwrap();
        obj2.value = 2;

        pool.return_by_index(0);

        // The returned slot is handed out next and keeps its previous value.
        let obj3 = pool.try_get().unwrap();
        assert_eq!(obj3.value, 1);
    }

    #[test]
    fn test_fixed_pool_all_slots_taken() {
        let mut pool: FixedPool<TestObject, 3> = FixedPool::new();
        pool.init();

        for _ in 0..3 {
            assert!(pool.try_get().is_some());
        }
        assert!(pool.try_get().is_none());

        pool.return_by_index(1);

        let obj4 = pool.try_get();
        assert!(obj4.is_some());

        assert!(pool.try_get().is_none());
    }
}