Skip to main content

http_handle/
request.rs

1// SPDX-License-Identifier: AGPL-3.0-only
2// Copyright (c) 2026 Sebastien Rousseau
3
4// src/request.rs
5
6//! HTTP/1.x request parsing and validation.
7//!
8//! Use this module to convert raw stream input into typed request data with bounded parsing,
9//! header normalization, and explicit malformed-request errors.
10
11use crate::error::ServerError;
12use std::collections::HashMap;
13use std::fmt;
14use std::io::{self, BufRead, BufReader};
15use std::net::TcpStream;
16use std::time::Duration;
17
/// Maximum length, in bytes, of the raw request line (just under 8 KiB).
///
/// The limit covers everything received for the line: method, path, version, the
/// separating whitespace, and the trailing line terminator (`\r\n`).
const MAX_REQUEST_LINE_LENGTH: usize = 8190;

/// Number of parts expected in a valid HTTP request line.
const REQUEST_PARTS: usize = 3;

/// Read timeout applied to the TCP stream, in seconds (per blocking read, not a
/// deadline for the whole request).
const TIMEOUT_SECONDS: u64 = 30;
/// Maximum number of distinct (lowercased) request header names accepted.
const MAX_HEADER_COUNT: usize = 100;
/// Maximum length of a single header line once trailing whitespace, including
/// the `\r\n` terminator, has been trimmed.
const MAX_HEADER_LINE_LENGTH: usize = 8192;
/// Maximum cumulative bytes for the header section, counting line terminators
/// and the terminating empty line.
const MAX_HEADER_BYTES: usize = 64 * 1024;
33
34fn map_timeout_error(error: io::Error) -> ServerError {
35    ServerError::invalid_request(format!(
36        "Failed to set read timeout: {}",
37        error
38    ))
39}
40
41fn map_read_error(error: io::Error) -> ServerError {
42    ServerError::invalid_request(format!(
43        "Failed to read request line: {}",
44        error
45    ))
46}
47
/// An HTTP/1.x request: the three request-line components plus a header map.
///
/// [`Request::from_stream`] produces values of this type from a live connection.
/// Every field is public, so a `Request` can also be assembled by hand, for
/// example in tests or when calling response-generation helpers directly.
///
/// # Examples
///
/// ```rust
/// use http_handle::request::Request;
/// use std::collections::HashMap;
///
/// let request = Request {
///     method: "GET".to_string(),
///     path: "/".to_string(),
///     version: "HTTP/1.1".to_string(),
///     headers: HashMap::new(),
/// };
/// assert_eq!(request.method(), "GET");
/// ```
///
/// # Panics
///
/// Building a `Request` never panics.
#[doc(alias = "http request")]
#[derive(Debug, Clone, PartialEq)]
pub struct Request {
    /// Request method exactly as received (e.g. `GET`); its case is not normalized.
    pub method: String,
    /// Request target, e.g. `/index.html`, or `*` for `OPTIONS *`.
    pub path: String,
    /// Protocol version token as received (e.g. `HTTP/1.1`).
    pub version: String,
    /// Headers keyed by lowercased name. [`Request::header`] lowercases its
    /// argument before the lookup, so hand-inserted keys should be lowercase too.
    pub headers: HashMap<String, String>,
}
83
84impl Request {
85    /// Parses a request line and headers from a `TcpStream`.
86    ///
87    /// This method reads the first line of an HTTP request from the given TCP stream,
88    /// parses it, and constructs a `Request` instance if the input is valid.
89    ///
90    /// # Arguments
91    ///
92    /// * `stream` - A reference to the `TcpStream` from which the request will be read.
93    ///
94    /// # Returns
95    ///
96    /// * `Ok(Request)` - If the request is valid and successfully parsed.
97    /// * `Err(ServerError)` - If the request is malformed, cannot be read, or is invalid.
98    ///
99    /// # Errors
100    ///
101    /// This function returns a `ServerError::InvalidRequest` error if:
102    /// - The request line is too long (exceeds `MAX_REQUEST_LINE_LENGTH`)
103    /// - The request line does not contain exactly three parts
104    /// - The HTTP method is not recognized
105    /// - The request path does not start with a forward slash (except `OPTIONS *`)
106    /// - The HTTP version is not supported (only HTTP/1.0 and HTTP/1.1 are accepted)
107    ///
108    /// # Examples
109    ///
110    /// ```rust,no_run
111    /// use std::net::TcpStream;
112    /// use http_handle::request::Request;
113    ///
114    /// let stream = TcpStream::connect("127.0.0.1:8080").expect("connect");
115    /// let parsed = Request::from_stream(&stream);
116    /// assert!(parsed.is_ok() || parsed.is_err());
117    /// ```
118    ///
119    /// # Panics
120    ///
121    /// This function does not intentionally panic.
122    #[doc(alias = "parse")]
123    #[doc(alias = "from tcp")]
124    pub fn from_stream(
125        stream: &TcpStream,
126    ) -> Result<Self, ServerError> {
127        stream
128            .set_read_timeout(Some(Duration::from_secs(
129                TIMEOUT_SECONDS,
130            )))
131            .map_err(map_timeout_error)?;
132
133        let mut buf_reader = BufReader::new(stream);
134        let mut request_line = String::new();
135
136        let _ = buf_reader
137            .read_line(&mut request_line)
138            .map_err(map_read_error)?;
139
140        // Trim the trailing \r\n before checking the length
141        let trimmed_request_line = request_line.trim_end();
142
143        // Check if the request line exceeds the maximum allowed length
144        if request_line.len() > MAX_REQUEST_LINE_LENGTH {
145            return Err(ServerError::invalid_request(format!(
146                "Request line too long: {} characters (max {})",
147                request_line.len(),
148                MAX_REQUEST_LINE_LENGTH
149            )));
150        }
151
152        let mut parts = trimmed_request_line.split_whitespace();
153        let Some(method_part) = parts.next() else {
154            return Err(ServerError::invalid_request(
155                "Invalid request line: missing method",
156            ));
157        };
158        let Some(path_part) = parts.next() else {
159            return Err(ServerError::invalid_request(
160                "Invalid request line: missing path",
161            ));
162        };
163        let Some(version_part) = parts.next() else {
164            return Err(ServerError::invalid_request(
165                "Invalid request line: missing HTTP version",
166            ));
167        };
168        if parts.next().is_some() {
169            return Err(ServerError::invalid_request(format!(
170                "Invalid request line: expected {} parts",
171                REQUEST_PARTS
172            )));
173        }
174
175        let method = method_part.to_string();
176        if !Self::is_valid_method(&method) {
177            return Err(ServerError::invalid_request(format!(
178                "Invalid HTTP method: {}",
179                method
180            )));
181        }
182
183        let path = path_part.to_string();
184        let is_options_asterisk =
185            method.eq_ignore_ascii_case("OPTIONS") && path == "*";
186        if !path.starts_with('/') && !is_options_asterisk {
187            return Err(ServerError::invalid_request(
188                "Invalid path: must start with '/' (or be '*' for OPTIONS)",
189            ));
190        }
191
192        let version = version_part.to_string();
193        if !Self::is_valid_version(&version) {
194            return Err(ServerError::invalid_request(format!(
195                "Invalid HTTP version: {}",
196                version
197            )));
198        }
199
200        let headers = Self::read_headers(&mut buf_reader)?;
201
202        Ok(Request {
203            method,
204            path,
205            version,
206            headers,
207        })
208    }
209
210    /// Returns the HTTP method of the request.
211    ///
212    /// # Returns
213    ///
214    /// A string slice containing the HTTP method (e.g., "GET", "POST").
215    pub fn method(&self) -> &str {
216        &self.method
217    }
218
219    /// Returns the requested path of the request.
220    ///
221    /// # Returns
222    ///
223    /// A string slice containing the requested path.
224    pub fn path(&self) -> &str {
225        &self.path
226    }
227
228    /// Returns the HTTP version of the request.
229    ///
230    /// # Returns
231    ///
232    /// A string slice containing the HTTP version (e.g., "HTTP/1.1").
233    pub fn version(&self) -> &str {
234        &self.version
235    }
236
237    /// Returns the value of a header by case-insensitive name.
238    ///
239    /// # Examples
240    ///
241    /// ```rust
242    /// use http_handle::request::Request;
243    /// use std::collections::HashMap;
244    ///
245    /// let mut headers = HashMap::new();
246    /// headers.insert("content-type".to_string(), "text/plain".to_string());
247    /// let request = Request {
248    ///     method: "GET".to_string(),
249    ///     path: "/".to_string(),
250    ///     version: "HTTP/1.1".to_string(),
251    ///     headers,
252    /// };
253    /// assert_eq!(request.header("Content-Type"), Some("text/plain"));
254    /// ```
255    ///
256    /// # Panics
257    ///
258    /// This function does not panic.
259    #[doc(alias = "header lookup")]
260    pub fn header(&self, name: &str) -> Option<&str> {
261        self.headers
262            .get(&name.to_ascii_lowercase())
263            .map(String::as_str)
264    }
265
266    /// Returns all parsed headers.
267    pub fn headers(&self) -> &HashMap<String, String> {
268        &self.headers
269    }
270
271    /// Checks if the given method is a valid HTTP method.
272    ///
273    /// # Arguments
274    ///
275    /// * `method` - A string slice containing the HTTP method to validate.
276    ///
277    /// # Returns
278    ///
279    /// `true` if the method is valid, `false` otherwise.
280    fn is_valid_method(method: &str) -> bool {
281        matches!(
282            method.to_ascii_uppercase().as_str(),
283            "GET"
284                | "POST"
285                | "PUT"
286                | "DELETE"
287                | "HEAD"
288                | "OPTIONS"
289                | "PATCH"
290        )
291    }
292
293    /// Checks if the given HTTP version is supported.
294    ///
295    /// # Arguments
296    ///
297    /// * `version` - A string slice containing the HTTP version to validate.
298    ///
299    /// # Returns
300    ///
301    /// `true` if the version is supported, `false` otherwise.
302    fn is_valid_version(version: &str) -> bool {
303        version.eq_ignore_ascii_case("HTTP/1.0")
304            || version.eq_ignore_ascii_case("HTTP/1.1")
305    }
306
307    fn read_headers<R: BufRead>(
308        reader: &mut R,
309    ) -> Result<HashMap<String, String>, ServerError> {
310        let mut headers = HashMap::with_capacity(16);
311        let mut total_bytes = 0_usize;
312
313        loop {
314            let mut line = String::new();
315            let bytes =
316                reader.read_line(&mut line).map_err(map_read_error)?;
317            if bytes == 0 {
318                break;
319            }
320            total_bytes = total_bytes.saturating_add(bytes);
321            if total_bytes > MAX_HEADER_BYTES {
322                return Err(ServerError::invalid_request(
323                    "Header section too large",
324                ));
325            }
326
327            let trimmed = line.trim_end();
328            if trimmed.is_empty() {
329                break;
330            }
331            if trimmed.len() > MAX_HEADER_LINE_LENGTH {
332                return Err(ServerError::invalid_request(
333                    "Header line too long",
334                ));
335            }
336            let (name, value) =
337                trimmed.split_once(':').ok_or_else(|| {
338                    ServerError::invalid_request(
339                        "Malformed header line",
340                    )
341                })?;
342            if headers.len() >= MAX_HEADER_COUNT {
343                return Err(ServerError::invalid_request(
344                    "Too many request headers",
345                ));
346            }
347            let _ = headers.insert(
348                name.trim().to_ascii_lowercase(),
349                value.trim().to_string(),
350            );
351        }
352
353        Ok(headers)
354    }
355}
356
357impl fmt::Display for Request {
358    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
359        write!(f, "{} {} {}", self.method, self.path, self.version)
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::net::TcpListener;

    /// Sends `payload` over a fresh loopback connection and parses the client side.
    ///
    /// A helper thread accepts the single connection, writes `payload`, and then
    /// drops its socket, so the parser sees end of stream right after the last byte.
    fn parse_payload(payload: &[u8]) -> Result<Request, ServerError> {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let bytes = payload.to_vec();

        let _ = std::thread::spawn(move || {
            let (mut peer, _) = listener.accept().unwrap();
            peer.write_all(&bytes).unwrap();
        });

        let client = TcpStream::connect(addr).unwrap();
        Request::from_stream(&client)
    }

    /// Asserts that `payload` is rejected with `ServerError::InvalidRequest`.
    fn assert_invalid_request(payload: &[u8]) {
        let outcome = parse_payload(payload);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            ServerError::InvalidRequest(_)
        ));
    }

    #[test]
    fn test_valid_request() {
        let request =
            parse_payload(b"GET /index.html HTTP/1.1\r\n").unwrap();

        assert_eq!(request.method(), "GET");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
        assert!(request.headers().is_empty());
    }

    #[test]
    fn test_invalid_method() {
        assert_invalid_request(b"INVALID /index.html HTTP/1.1\r\n");
    }

    #[test]
    fn test_max_length_request() {
        // "GET " (4) + path + " HTTP/1.1" (9) + "\r\n" (2) leaves the raw line
        // one byte under the limit.
        let path_len = MAX_REQUEST_LINE_LENGTH - 16;
        let payload =
            format!("GET {} HTTP/1.1\r\n", "/".repeat(path_len));
        let outcome = parse_payload(payload.as_bytes());

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().path().len(), path_len);
    }

    #[test]
    fn test_oversized_request() {
        // 13 = len("GET  HTTP/1.1"): the trimmed line equals the limit, so the
        // raw line (with "\r\n") is two bytes over it.
        let path_len = MAX_REQUEST_LINE_LENGTH - 13;
        let payload =
            format!("GET {} HTTP/1.1\r\n", "/".repeat(path_len));
        let outcome = parse_payload(payload.as_bytes());

        assert!(
            outcome.is_err(),
            "Oversized request should be invalid. Request: {:?}",
            outcome
        );
        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("Request line too long:"),
            "Unexpected error message: {}",
            msg
        );
    }

    #[test]
    fn test_invalid_path() {
        assert_invalid_request(b"GET index.html HTTP/1.1\r\n");
    }

    #[test]
    fn test_invalid_version() {
        assert_invalid_request(b"GET /index.html HTTP/2.0\r\n");
    }

    #[test]
    fn test_head_request() {
        let request =
            parse_payload(b"HEAD /index.html HTTP/1.1\r\n").unwrap();

        assert_eq!(request.method(), "HEAD");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_options_request() {
        let request = parse_payload(b"OPTIONS * HTTP/1.1\r\n").unwrap();

        assert_eq!(request.method(), "OPTIONS");
        assert_eq!(request.path(), "*");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_internal_error_mapping_helpers() {
        let timeout_text = map_timeout_error(io::Error::new(
            io::ErrorKind::TimedOut,
            "timeout",
        ))
        .to_string();
        assert!(timeout_text.contains("Failed to set read timeout"));

        let read_text = map_read_error(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "eof",
        ))
        .to_string();
        assert!(read_text.contains("Failed to read request line"));
    }

    #[test]
    fn test_parses_headers() {
        let request = parse_payload(
            b"GET /index.html HTTP/1.1\r\nHost: localhost\r\nRange: bytes=0-1\r\n\r\n",
        )
        .unwrap();

        assert_eq!(request.header("host"), Some("localhost"));
        assert_eq!(request.header("range"), Some("bytes=0-1"));
    }
}