http_handle/distributed_rate_limit.rs
1// SPDX-License-Identifier: AGPL-3.0-only
2// Copyright (c) 2026 Sebastien Rousseau
3
4//! Distributed rate-limiting adapters and backend contracts.
5
6use crate::error::ServerError;
7use std::collections::HashMap;
8use std::net::IpAddr;
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11
12/// Backend trait for incrementing a rate-limit key in a time window.
13///
14/// # Examples
15///
16/// ```rust
17/// use http_handle::distributed_rate_limit::RateLimitBackend;
18/// # let _ = std::any::TypeId::of::<&dyn RateLimitBackend>();
19/// assert_eq!(2 + 2, 4);
20/// ```
21///
22/// # Panics
23///
24/// Trait usage does not panic by itself.
25pub trait RateLimitBackend: Send + Sync + std::fmt::Debug {
26 /// Increments key and returns current hit count for the active window.
27 fn increment_and_get(
28 &self,
29 key: &str,
30 window_secs: u64,
31 ) -> Result<u64, ServerError>;
32}
33
34/// Shared rate limiter that works against pluggable backends.
35///
36/// # Examples
37///
38/// ```rust
39/// use http_handle::distributed_rate_limit::{DistributedRateLimiter, InMemoryBackend};
40/// let _limiter = DistributedRateLimiter::new(InMemoryBackend::default(), "ip", 100, 60);
41/// assert_eq!(1, 1);
42/// ```
43///
44/// # Panics
45///
46/// This type does not panic.
47#[derive(Clone, Debug)]
48pub struct DistributedRateLimiter<B: RateLimitBackend> {
49 backend: Arc<B>,
50 namespace: String,
51 limit_per_window: u64,
52 window_secs: u64,
53}
54
55impl<B: RateLimitBackend> DistributedRateLimiter<B> {
56 /// Creates a distributed limiter with explicit namespace and limits.
57 ///
58 /// # Examples
59 ///
60 /// ```rust
61 /// use http_handle::distributed_rate_limit::{DistributedRateLimiter, InMemoryBackend};
62 /// let _ = DistributedRateLimiter::new(InMemoryBackend::default(), "ip", 10, 60);
63 /// assert_eq!(1, 1);
64 /// ```
65 ///
66 /// # Panics
67 ///
68 /// This function does not panic.
69 pub fn new(
70 backend: B,
71 namespace: impl Into<String>,
72 limit_per_window: u64,
73 window_secs: u64,
74 ) -> Self {
75 Self {
76 backend: Arc::new(backend),
77 namespace: namespace.into(),
78 limit_per_window: limit_per_window.max(1),
79 window_secs: window_secs.max(1),
80 }
81 }
82
83 /// Returns true when the source should be throttled.
84 ///
85 /// # Examples
86 ///
87 /// ```rust
88 /// use http_handle::distributed_rate_limit::{DistributedRateLimiter, InMemoryBackend};
89 /// use std::net::IpAddr;
90 /// let limiter = DistributedRateLimiter::new(InMemoryBackend::default(), "ip", 1, 60);
91 /// let ip: IpAddr = "127.0.0.1".parse().expect("ip");
92 /// let _ = limiter.is_limited(ip);
93 /// assert_eq!(1, 1);
94 /// ```
95 ///
96 /// # Errors
97 ///
98 /// Returns backend errors when increment operations fail.
99 ///
100 /// # Panics
101 ///
102 /// This function does not panic.
103 pub fn is_limited(
104 &self,
105 source: IpAddr,
106 ) -> Result<bool, ServerError> {
107 let key = format!("{}:{}", self.namespace, source);
108 let current =
109 self.backend.increment_and_get(&key, self.window_secs)?;
110 Ok(current > self.limit_per_window)
111 }
112}
113
114/// In-memory backend useful for local fallback mode and tests.
115///
116/// # Examples
117///
118/// ```rust
119/// use http_handle::distributed_rate_limit::InMemoryBackend;
120/// let _backend = InMemoryBackend::default();
121/// assert_eq!(1, 1);
122/// ```
123///
124/// # Panics
125///
126/// This type does not panic.
127#[derive(Debug, Default)]
128pub struct InMemoryBackend {
129 state: Mutex<HashMap<String, Vec<Instant>>>,
130}
131
132impl RateLimitBackend for InMemoryBackend {
133 fn increment_and_get(
134 &self,
135 key: &str,
136 window_secs: u64,
137 ) -> Result<u64, ServerError> {
138 let now = Instant::now();
139 let mut state = self.state.lock().map_err(|_| {
140 ServerError::Custom("rate state poisoned".into())
141 })?;
142 let hits = state.entry(key.to_string()).or_default();
143 hits.retain(|ts| {
144 now.duration_since(*ts) <= Duration::from_secs(window_secs)
145 });
146 hits.push(now);
147 Ok(hits.len() as u64)
148 }
149}
150
151/// Minimal Redis-like client contract.
152///
153/// # Examples
154///
155/// ```rust
156/// use http_handle::distributed_rate_limit::RedisClient;
157/// # let _ = std::any::TypeId::of::<&dyn RedisClient>();
158/// assert_eq!(1, 1);
159/// ```
160///
161/// # Panics
162///
163/// Trait usage does not panic by itself.
164pub trait RedisClient: Send + Sync + std::fmt::Debug {
165 /// Increments key, sets TTL as needed, and returns current count.
166 fn incr_with_ttl(
167 &self,
168 key: &str,
169 ttl_secs: u64,
170 ) -> Result<u64, ServerError>;
171}
172
173/// Redis backend adapter.
174///
175/// # Examples
176///
177/// ```rust
178/// use http_handle::distributed_rate_limit::{RedisBackend, RedisClient};
179/// use http_handle::ServerError;
180/// #[derive(Debug)]
181/// struct Dummy;
182/// impl RedisClient for Dummy {
183/// fn incr_with_ttl(&self, _key: &str, _ttl_secs: u64) -> Result<u64, ServerError> { Ok(1) }
184/// }
185/// let _backend = RedisBackend::new(Dummy);
186/// assert_eq!(1, 1);
187/// ```
188///
189/// # Panics
190///
191/// This type does not panic.
192#[derive(Debug)]
193pub struct RedisBackend<C: RedisClient> {
194 client: C,
195}
196
197impl<C: RedisClient> RedisBackend<C> {
198 /// Creates a new Redis backend adapter.
199 ///
200 /// # Examples
201 ///
202 /// ```rust
203 /// use http_handle::distributed_rate_limit::{RedisBackend, RedisClient};
204 /// use http_handle::ServerError;
205 /// #[derive(Debug)]
206 /// struct Dummy;
207 /// impl RedisClient for Dummy {
208 /// fn incr_with_ttl(&self, _key: &str, _ttl_secs: u64) -> Result<u64, ServerError> { Ok(1) }
209 /// }
210 /// let _backend = RedisBackend::new(Dummy);
211 /// assert_eq!(1, 1);
212 /// ```
213 ///
214 /// # Panics
215 ///
216 /// This function does not panic.
217 pub fn new(client: C) -> Self {
218 Self { client }
219 }
220}
221
222impl<C: RedisClient> RateLimitBackend for RedisBackend<C> {
223 fn increment_and_get(
224 &self,
225 key: &str,
226 window_secs: u64,
227 ) -> Result<u64, ServerError> {
228 self.client.incr_with_ttl(key, window_secs)
229 }
230}
231
232/// Minimal Memcached-like client contract.
233///
234/// # Examples
235///
236/// ```rust
237/// use http_handle::distributed_rate_limit::MemcachedClient;
238/// # let _ = std::any::TypeId::of::<&dyn MemcachedClient>();
239/// assert_eq!(1, 1);
240/// ```
241///
242/// # Panics
243///
244/// Trait usage does not panic by itself.
245pub trait MemcachedClient: Send + Sync + std::fmt::Debug {
246 /// Increments key and returns current count.
247 fn incr(
248 &self,
249 key: &str,
250 initial: u64,
251 ttl_secs: u32,
252 ) -> Result<u64, ServerError>;
253}
254
255/// Memcached backend adapter.
256///
257/// # Examples
258///
259/// ```rust
260/// use http_handle::distributed_rate_limit::{MemcachedBackend, MemcachedClient};
261/// use http_handle::ServerError;
262/// #[derive(Debug)]
263/// struct Dummy;
264/// impl MemcachedClient for Dummy {
265/// fn incr(&self, _key: &str, _initial: u64, _ttl_secs: u32) -> Result<u64, ServerError> { Ok(1) }
266/// }
267/// let _backend = MemcachedBackend::new(Dummy);
268/// assert_eq!(1, 1);
269/// ```
270///
271/// # Panics
272///
273/// This type does not panic.
274#[derive(Debug)]
275pub struct MemcachedBackend<C: MemcachedClient> {
276 client: C,
277}
278
279impl<C: MemcachedClient> MemcachedBackend<C> {
280 /// Creates a new Memcached backend adapter.
281 ///
282 /// # Examples
283 ///
284 /// ```rust
285 /// use http_handle::distributed_rate_limit::{MemcachedBackend, MemcachedClient};
286 /// use http_handle::ServerError;
287 /// #[derive(Debug)]
288 /// struct Dummy;
289 /// impl MemcachedClient for Dummy {
290 /// fn incr(&self, _key: &str, _initial: u64, _ttl_secs: u32) -> Result<u64, ServerError> { Ok(1) }
291 /// }
292 /// let _backend = MemcachedBackend::new(Dummy);
293 /// assert_eq!(1, 1);
294 /// ```
295 ///
296 /// # Panics
297 ///
298 /// This function does not panic.
299 pub fn new(client: C) -> Self {
300 Self { client }
301 }
302}
303
304impl<C: MemcachedClient> RateLimitBackend for MemcachedBackend<C> {
305 fn increment_and_get(
306 &self,
307 key: &str,
308 window_secs: u64,
309 ) -> Result<u64, ServerError> {
310 self.client.incr(key, 1, window_secs as u32)
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Counter table shared by the Redis and Memcached mocks.
    type Counts = Mutex<HashMap<String, u64>>;

    /// Locks a mock's counter table, reporting poisoning as a `ServerError`.
    fn lock_counts(
        counts: &Counts,
    ) -> Result<std::sync::MutexGuard<'_, HashMap<String, u64>>, ServerError> {
        counts
            .lock()
            .map_err(|_| ServerError::Custom("poisoned".into()))
    }

    /// Parses an address literal used as a rate-limit source.
    fn ip(text: &str) -> IpAddr {
        text.parse().expect("ip")
    }

    /// Runs a limit-of-one limiter over `backend` twice for `addr` and
    /// expects only the second call to be throttled.
    fn expect_second_hit_limited<B: RateLimitBackend>(backend: B, addr: &str) {
        let limiter = DistributedRateLimiter::new(backend, "ip", 1, 60);
        let source = ip(addr);
        assert!(!limiter.is_limited(source).expect("limit"));
        assert!(limiter.is_limited(source).expect("limit"));
    }

    /// Redis mock: an unbounded per-key counter that ignores the TTL.
    #[derive(Debug, Default)]
    struct MockRedis {
        counts: Counts,
    }

    impl RedisClient for MockRedis {
        fn incr_with_ttl(&self, key: &str, _ttl_secs: u64) -> Result<u64, ServerError> {
            let mut table = lock_counts(&self.counts)?;
            let hits = table.entry(key.to_string()).or_insert(0);
            *hits += 1;
            Ok(*hits)
        }
    }

    /// Memcached mock: seeds unknown keys with `initial`, otherwise adds one.
    #[derive(Debug, Default)]
    struct MockMemcached {
        counts: Counts,
    }

    impl MemcachedClient for MockMemcached {
        fn incr(&self, key: &str, initial: u64, _ttl_secs: u32) -> Result<u64, ServerError> {
            let mut table = lock_counts(&self.counts)?;
            match table.get_mut(key) {
                Some(hits) => {
                    *hits += 1;
                    Ok(*hits)
                }
                None => {
                    let _ = table.insert(key.to_string(), initial);
                    Ok(initial)
                }
            }
        }
    }

    #[test]
    fn in_memory_backend_enforces_limit() {
        let limiter = DistributedRateLimiter::new(InMemoryBackend::default(), "ip", 2, 60);
        let source = ip("127.0.0.1");
        let verdicts: Vec<bool> = (0..3)
            .map(|_| limiter.is_limited(source).expect("limit"))
            .collect();
        assert_eq!(verdicts, [false, false, true]);
    }

    #[test]
    fn redis_adapter_routes_calls() {
        expect_second_hit_limited(RedisBackend::new(MockRedis::default()), "127.0.0.2");
    }

    #[test]
    fn memcached_adapter_routes_calls() {
        expect_second_hit_limited(
            MemcachedBackend::new(MockMemcached::default()),
            "127.0.0.3",
        );
    }

    #[test]
    fn limiter_propagates_backend_errors() {
        /// Backend whose every call fails.
        #[derive(Debug)]
        struct FailingBackend;

        impl RateLimitBackend for FailingBackend {
            fn increment_and_get(
                &self,
                _key: &str,
                _window_secs: u64,
            ) -> Result<u64, ServerError> {
                Err(ServerError::Custom("backend down".into()))
            }
        }

        let limiter = DistributedRateLimiter::new(FailingBackend, "ip", 0, 0);
        let err = limiter
            .is_limited(ip("127.0.0.9"))
            .expect_err("must fail");
        assert!(err.to_string().contains("backend down"));
    }

    #[test]
    fn in_memory_backend_maps_poisoned_lock_to_error() {
        let backend = InMemoryBackend::default();
        // Panic while holding the lock so the mutex becomes poisoned.
        let _ = std::panic::catch_unwind(|| {
            let _guard = backend.state.lock().expect("lock");
            panic!("poison lock");
        });
        let outcome = backend.increment_and_get("ip:127.0.0.1", 60);
        let err = outcome.expect_err("poisoned lock should error");
        assert!(err.to_string().contains("poisoned"));
    }
}