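// SPDX-License-Identifier: AGPL-3.0-only
// Copyright (c) 2026 Sebastien Rousseau

//! Distributed rate-limiting adapters and backend contracts.
//!
//! [`DistributedRateLimiter`] counts hits per source IP through a pluggable
//! [`RateLimitBackend`]. This module ships an in-process backend plus
//! adapters for Redis-like and Memcached-like clients.
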
use crate::error::ServerError;
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

12/// Backend trait for incrementing a rate-limit key in a time window.
13///
14/// # Examples
15///
16/// ```rust
17/// use http_handle::distributed_rate_limit::RateLimitBackend;
18/// # let _ = std::any::TypeId::of::<&dyn RateLimitBackend>();
19/// assert_eq!(2 + 2, 4);
20/// ```
21///
22/// # Panics
23///
24/// Trait usage does not panic by itself.
25pub trait RateLimitBackend: Send + Sync + std::fmt::Debug {
26    /// Increments key and returns current hit count for the active window.
27    fn increment_and_get(
28        &self,
29        key: &str,
30        window_secs: u64,
31    ) -> Result<u64, ServerError>;
32}
33
34/// Shared rate limiter that works against pluggable backends.
35///
36/// # Examples
37///
38/// ```rust
39/// use http_handle::distributed_rate_limit::{DistributedRateLimiter, InMemoryBackend};
40/// let _limiter = DistributedRateLimiter::new(InMemoryBackend::default(), "ip", 100, 60);
41/// assert_eq!(1, 1);
42/// ```
43///
44/// # Panics
45///
46/// This type does not panic.
47#[derive(Clone, Debug)]
48pub struct DistributedRateLimiter<B: RateLimitBackend> {
49    backend: Arc<B>,
50    namespace: String,
51    limit_per_window: u64,
52    window_secs: u64,
53}
54
55impl<B: RateLimitBackend> DistributedRateLimiter<B> {
56    /// Creates a distributed limiter with explicit namespace and limits.
57    ///
58    /// # Examples
59    ///
60    /// ```rust
61    /// use http_handle::distributed_rate_limit::{DistributedRateLimiter, InMemoryBackend};
62    /// let _ = DistributedRateLimiter::new(InMemoryBackend::default(), "ip", 10, 60);
63    /// assert_eq!(1, 1);
64    /// ```
65    ///
66    /// # Panics
67    ///
68    /// This function does not panic.
69    pub fn new(
70        backend: B,
71        namespace: impl Into<String>,
72        limit_per_window: u64,
73        window_secs: u64,
74    ) -> Self {
75        Self {
76            backend: Arc::new(backend),
77            namespace: namespace.into(),
78            limit_per_window: limit_per_window.max(1),
79            window_secs: window_secs.max(1),
80        }
81    }
82
83    /// Returns true when the source should be throttled.
84    ///
85    /// # Examples
86    ///
87    /// ```rust
88    /// use http_handle::distributed_rate_limit::{DistributedRateLimiter, InMemoryBackend};
89    /// use std::net::IpAddr;
90    /// let limiter = DistributedRateLimiter::new(InMemoryBackend::default(), "ip", 1, 60);
91    /// let ip: IpAddr = "127.0.0.1".parse().expect("ip");
92    /// let _ = limiter.is_limited(ip);
93    /// assert_eq!(1, 1);
94    /// ```
95    ///
96    /// # Errors
97    ///
98    /// Returns backend errors when increment operations fail.
99    ///
100    /// # Panics
101    ///
102    /// This function does not panic.
103    pub fn is_limited(
104        &self,
105        source: IpAddr,
106    ) -> Result<bool, ServerError> {
107        let key = format!("{}:{}", self.namespace, source);
108        let current =
109            self.backend.increment_and_get(&key, self.window_secs)?;
110        Ok(current > self.limit_per_window)
111    }
112}
113
/// In-memory backend useful for local fallback mode and tests.
///
/// Keeps a sliding-window log per key: the timestamp of every hit that is
/// still inside the window. State lives in this process only, so separate
/// server instances do not share counts; use a shared store such as Redis
/// or Memcached for that. All keys sit behind a single [`Mutex`].
///
/// # Examples
///
/// ```rust
/// use http_handle::distributed_rate_limit::{InMemoryBackend, RateLimitBackend};
///
/// let backend = InMemoryBackend::default();
/// assert_eq!(backend.increment_and_get("ip:127.0.0.1", 60).expect("backend"), 1);
/// assert_eq!(backend.increment_and_get("ip:127.0.0.1", 60).expect("backend"), 2);
/// // Keys are counted independently.
/// assert_eq!(backend.increment_and_get("ip:127.0.0.2", 60).expect("backend"), 1);
/// ```
///
/// # Panics
///
/// This type does not panic.
#[derive(Debug, Default)]
pub struct InMemoryBackend {
    /// Hit timestamps per key. A key's log is pruned to the active window
    /// whenever that key is incremented.
    state: Mutex<HashMap<String, Vec<Instant>>>,
}

132impl RateLimitBackend for InMemoryBackend {
133    fn increment_and_get(
134        &self,
135        key: &str,
136        window_secs: u64,
137    ) -> Result<u64, ServerError> {
138        let now = Instant::now();
139        let mut state = self.state.lock().map_err(|_| {
140            ServerError::Custom("rate state poisoned".into())
141        })?;
142        let hits = state.entry(key.to_string()).or_default();
143        hits.retain(|ts| {
144            now.duration_since(*ts) <= Duration::from_secs(window_secs)
145        });
146        hits.push(now);
147        Ok(hits.len() as u64)
148    }
149}
150
151/// Minimal Redis-like client contract.
152///
153/// # Examples
154///
155/// ```rust
156/// use http_handle::distributed_rate_limit::RedisClient;
157/// # let _ = std::any::TypeId::of::<&dyn RedisClient>();
158/// assert_eq!(1, 1);
159/// ```
160///
161/// # Panics
162///
163/// Trait usage does not panic by itself.
164pub trait RedisClient: Send + Sync + std::fmt::Debug {
165    /// Increments key, sets TTL as needed, and returns current count.
166    fn incr_with_ttl(
167        &self,
168        key: &str,
169        ttl_secs: u64,
170    ) -> Result<u64, ServerError>;
171}
172
173/// Redis backend adapter.
174///
175/// # Examples
176///
177/// ```rust
178/// use http_handle::distributed_rate_limit::{RedisBackend, RedisClient};
179/// use http_handle::ServerError;
180/// #[derive(Debug)]
181/// struct Dummy;
182/// impl RedisClient for Dummy {
183///     fn incr_with_ttl(&self, _key: &str, _ttl_secs: u64) -> Result<u64, ServerError> { Ok(1) }
184/// }
185/// let _backend = RedisBackend::new(Dummy);
186/// assert_eq!(1, 1);
187/// ```
188///
189/// # Panics
190///
191/// This type does not panic.
192#[derive(Debug)]
193pub struct RedisBackend<C: RedisClient> {
194    client: C,
195}
196
197impl<C: RedisClient> RedisBackend<C> {
198    /// Creates a new Redis backend adapter.
199    ///
200    /// # Examples
201    ///
202    /// ```rust
203    /// use http_handle::distributed_rate_limit::{RedisBackend, RedisClient};
204    /// use http_handle::ServerError;
205    /// #[derive(Debug)]
206    /// struct Dummy;
207    /// impl RedisClient for Dummy {
208    ///     fn incr_with_ttl(&self, _key: &str, _ttl_secs: u64) -> Result<u64, ServerError> { Ok(1) }
209    /// }
210    /// let _backend = RedisBackend::new(Dummy);
211    /// assert_eq!(1, 1);
212    /// ```
213    ///
214    /// # Panics
215    ///
216    /// This function does not panic.
217    pub fn new(client: C) -> Self {
218        Self { client }
219    }
220}
221
222impl<C: RedisClient> RateLimitBackend for RedisBackend<C> {
223    fn increment_and_get(
224        &self,
225        key: &str,
226        window_secs: u64,
227    ) -> Result<u64, ServerError> {
228        self.client.incr_with_ttl(key, window_secs)
229    }
230}
231
232/// Minimal Memcached-like client contract.
233///
234/// # Examples
235///
236/// ```rust
237/// use http_handle::distributed_rate_limit::MemcachedClient;
238/// # let _ = std::any::TypeId::of::<&dyn MemcachedClient>();
239/// assert_eq!(1, 1);
240/// ```
241///
242/// # Panics
243///
244/// Trait usage does not panic by itself.
245pub trait MemcachedClient: Send + Sync + std::fmt::Debug {
246    /// Increments key and returns current count.
247    fn incr(
248        &self,
249        key: &str,
250        initial: u64,
251        ttl_secs: u32,
252    ) -> Result<u64, ServerError>;
253}
254
255/// Memcached backend adapter.
256///
257/// # Examples
258///
259/// ```rust
260/// use http_handle::distributed_rate_limit::{MemcachedBackend, MemcachedClient};
261/// use http_handle::ServerError;
262/// #[derive(Debug)]
263/// struct Dummy;
264/// impl MemcachedClient for Dummy {
265///     fn incr(&self, _key: &str, _initial: u64, _ttl_secs: u32) -> Result<u64, ServerError> { Ok(1) }
266/// }
267/// let _backend = MemcachedBackend::new(Dummy);
268/// assert_eq!(1, 1);
269/// ```
270///
271/// # Panics
272///
273/// This type does not panic.
274#[derive(Debug)]
275pub struct MemcachedBackend<C: MemcachedClient> {
276    client: C,
277}
278
279impl<C: MemcachedClient> MemcachedBackend<C> {
280    /// Creates a new Memcached backend adapter.
281    ///
282    /// # Examples
283    ///
284    /// ```rust
285    /// use http_handle::distributed_rate_limit::{MemcachedBackend, MemcachedClient};
286    /// use http_handle::ServerError;
287    /// #[derive(Debug)]
288    /// struct Dummy;
289    /// impl MemcachedClient for Dummy {
290    ///     fn incr(&self, _key: &str, _initial: u64, _ttl_secs: u32) -> Result<u64, ServerError> { Ok(1) }
291    /// }
292    /// let _backend = MemcachedBackend::new(Dummy);
293    /// assert_eq!(1, 1);
294    /// ```
295    ///
296    /// # Panics
297    ///
298    /// This function does not panic.
299    pub fn new(client: C) -> Self {
300        Self { client }
301    }
302}
303
304impl<C: MemcachedClient> RateLimitBackend for MemcachedBackend<C> {
305    fn increment_and_get(
306        &self,
307        key: &str,
308        window_secs: u64,
309    ) -> Result<u64, ServerError> {
310        self.client.incr(key, 1, window_secs as u32)
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Redis stand-in that counts increments per key and ignores the TTL.
    #[derive(Debug, Default)]
    struct MockRedis {
        counts: Mutex<HashMap<String, u64>>,
    }
    impl RedisClient for MockRedis {
        fn incr_with_ttl(
            &self,
            key: &str,
            _ttl_secs: u64,
        ) -> Result<u64, ServerError> {
            let mut counts = self
                .counts
                .lock()
                .map_err(|_| ServerError::Custom("poisoned".into()))?;
            let entry = counts.entry(key.to_string()).or_insert(0);
            *entry += 1;
            Ok(*entry)
        }
    }

    /// Memcached stand-in: stores `initial` for new keys, then increments.
    #[derive(Debug, Default)]
    struct MockMemcached {
        counts: Mutex<HashMap<String, u64>>,
    }
    impl MemcachedClient for MockMemcached {
        fn incr(
            &self,
            key: &str,
            initial: u64,
            _ttl_secs: u32,
        ) -> Result<u64, ServerError> {
            let mut counts = self
                .counts
                .lock()
                .map_err(|_| ServerError::Custom("poisoned".into()))?;
            if let Some(entry) = counts.get_mut(key) {
                *entry += 1;
                Ok(*entry)
            } else {
                let _ = counts.insert(key.to_string(), initial);
                Ok(initial)
            }
        }
    }

    #[test]
    fn in_memory_backend_enforces_limit() {
        let limiter = DistributedRateLimiter::new(
            InMemoryBackend::default(),
            "ip",
            2,
            60,
        );
        let ip: IpAddr = "127.0.0.1".parse().expect("ip");
        assert!(!limiter.is_limited(ip).expect("limit"));
        assert!(!limiter.is_limited(ip).expect("limit"));
        assert!(limiter.is_limited(ip).expect("limit"));
    }

    #[test]
    fn redis_adapter_routes_calls() {
        let backend = RedisBackend::new(MockRedis::default());
        let limiter = DistributedRateLimiter::new(backend, "ip", 1, 60);
        let ip: IpAddr = "127.0.0.2".parse().expect("ip");
        assert!(!limiter.is_limited(ip).expect("limit"));
        assert!(limiter.is_limited(ip).expect("limit"));
    }

    #[test]
    fn memcached_adapter_routes_calls() {
        let backend = MemcachedBackend::new(MockMemcached::default());
        let limiter = DistributedRateLimiter::new(backend, "ip", 1, 60);
        let ip: IpAddr = "127.0.0.3".parse().expect("ip");
        assert!(!limiter.is_limited(ip).expect("limit"));
        assert!(limiter.is_limited(ip).expect("limit"));
    }

    #[test]
    fn limits_are_tracked_per_source_ip() {
        let limiter = DistributedRateLimiter::new(
            InMemoryBackend::default(),
            "ip",
            1,
            60,
        );
        let first: IpAddr = "127.0.0.4".parse().expect("ip");
        let second: IpAddr = "127.0.0.5".parse().expect("ip");
        assert!(!limiter.is_limited(first).expect("limit"));
        assert!(limiter.is_limited(first).expect("limit"));
        // A throttled source must not consume another source's budget.
        assert!(!limiter.is_limited(second).expect("limit"));
    }

    #[test]
    fn zero_limit_and_window_are_clamped_to_one() {
        let limiter = DistributedRateLimiter::new(
            InMemoryBackend::default(),
            "ip",
            0,
            0,
        );
        assert_eq!(limiter.limit_per_window, 1);
        assert_eq!(limiter.window_secs, 1);
        let ip: IpAddr = "127.0.0.6".parse().expect("ip");
        assert!(!limiter.is_limited(ip).expect("limit"));
        assert!(limiter.is_limited(ip).expect("limit"));
    }

    #[test]
    fn limiter_propagates_backend_errors() {
        #[derive(Debug)]
        struct FailingBackend;
        impl RateLimitBackend for FailingBackend {
            fn increment_and_get(
                &self,
                _key: &str,
                _window_secs: u64,
            ) -> Result<u64, ServerError> {
                Err(ServerError::Custom("backend down".into()))
            }
        }

        let limiter =
            DistributedRateLimiter::new(FailingBackend, "ip", 0, 0);
        let ip: IpAddr = "127.0.0.9".parse().expect("ip");
        let err = limiter.is_limited(ip).expect_err("must fail");
        assert!(err.to_string().contains("backend down"));
    }

    #[test]
    fn in_memory_backend_maps_poisoned_lock_to_error() {
        let backend = InMemoryBackend::default();
        let _ = std::panic::catch_unwind(|| {
            let _guard = backend.state.lock().expect("lock");
            panic!("poison lock");
        });
        let err = backend
            .increment_and_get("ip:127.0.0.1", 60)
            .expect_err("poisoned lock should error");
        assert!(err.to_string().contains("poisoned"));
    }
}
428}