1use crate::error::ServerError;
12use std::collections::HashMap;
13use std::fmt;
14use std::io::{self, BufRead, BufReader};
15use std::net::TcpStream;
16use std::time::Duration;
17
/// Maximum accepted length, in bytes, of the request line as read from the
/// stream (the check is made on the untrimmed line, so the line terminator
/// counts towards the limit).
const MAX_REQUEST_LINE_LENGTH: usize = 8190;

/// Number of whitespace-separated parts a request line must have
/// (method, path, version); reported in the "expected N parts" error.
const REQUEST_PARTS: usize = 3;

/// Read timeout, in seconds, set on the socket before parsing starts.
const TIMEOUT_SECONDS: u64 = 30;
/// Maximum number of entries the parsed header map may hold before a
/// further header line is rejected.
const MAX_HEADER_COUNT: usize = 100;
/// Maximum length, in bytes, of a single header line after trailing
/// whitespace has been trimmed.
const MAX_HEADER_LINE_LENGTH: usize = 8192;
/// Maximum total number of bytes read for the header section, counted over
/// the raw lines (line terminators included).
const MAX_HEADER_BYTES: usize = 64 * 1024;
33
34fn map_timeout_error(error: io::Error) -> ServerError {
35 ServerError::invalid_request(format!(
36 "Failed to set read timeout: {}",
37 error
38 ))
39}
40
41fn map_read_error(error: io::Error) -> ServerError {
42 ServerError::invalid_request(format!(
43 "Failed to read request line: {}",
44 error
45 ))
46}
47
/// The head of an HTTP request: the three request-line components plus the
/// header fields.
///
/// Values are normally produced by [`Request::from_stream`]; the notes on
/// each field describe what that parser stores.
#[doc(alias = "http request")]
#[derive(Debug, Clone, PartialEq)]
pub struct Request {
    /// Request method exactly as it appeared on the request line (the parser
    /// validates it case-insensitively but does not normalize its case).
    pub method: String,
    /// Request target: a path starting with `/`, or `*` for `OPTIONS`.
    pub path: String,
    /// Protocol version token as received, e.g. `HTTP/1.1`.
    pub version: String,
    /// Header fields; the parser stores names lowercased and values with
    /// surrounding whitespace trimmed.
    pub headers: HashMap<String, String>,
}
83
84impl Request {
85 #[doc(alias = "parse")]
123 #[doc(alias = "from tcp")]
124 pub fn from_stream(
125 stream: &TcpStream,
126 ) -> Result<Self, ServerError> {
127 stream
128 .set_read_timeout(Some(Duration::from_secs(
129 TIMEOUT_SECONDS,
130 )))
131 .map_err(map_timeout_error)?;
132
133 let mut buf_reader = BufReader::new(stream);
134 let mut request_line = String::new();
135
136 let _ = buf_reader
137 .read_line(&mut request_line)
138 .map_err(map_read_error)?;
139
140 let trimmed_request_line = request_line.trim_end();
142
143 if request_line.len() > MAX_REQUEST_LINE_LENGTH {
145 return Err(ServerError::invalid_request(format!(
146 "Request line too long: {} characters (max {})",
147 request_line.len(),
148 MAX_REQUEST_LINE_LENGTH
149 )));
150 }
151
152 let mut parts = trimmed_request_line.split_whitespace();
153 let Some(method_part) = parts.next() else {
154 return Err(ServerError::invalid_request(
155 "Invalid request line: missing method",
156 ));
157 };
158 let Some(path_part) = parts.next() else {
159 return Err(ServerError::invalid_request(
160 "Invalid request line: missing path",
161 ));
162 };
163 let Some(version_part) = parts.next() else {
164 return Err(ServerError::invalid_request(
165 "Invalid request line: missing HTTP version",
166 ));
167 };
168 if parts.next().is_some() {
169 return Err(ServerError::invalid_request(format!(
170 "Invalid request line: expected {} parts",
171 REQUEST_PARTS
172 )));
173 }
174
175 let method = method_part.to_string();
176 if !Self::is_valid_method(&method) {
177 return Err(ServerError::invalid_request(format!(
178 "Invalid HTTP method: {}",
179 method
180 )));
181 }
182
183 let path = path_part.to_string();
184 let is_options_asterisk =
185 method.eq_ignore_ascii_case("OPTIONS") && path == "*";
186 if !path.starts_with('/') && !is_options_asterisk {
187 return Err(ServerError::invalid_request(
188 "Invalid path: must start with '/' (or be '*' for OPTIONS)",
189 ));
190 }
191
192 let version = version_part.to_string();
193 if !Self::is_valid_version(&version) {
194 return Err(ServerError::invalid_request(format!(
195 "Invalid HTTP version: {}",
196 version
197 )));
198 }
199
200 let headers = Self::read_headers(&mut buf_reader)?;
201
202 Ok(Request {
203 method,
204 path,
205 version,
206 headers,
207 })
208 }
209
210 pub fn method(&self) -> &str {
216 &self.method
217 }
218
219 pub fn path(&self) -> &str {
225 &self.path
226 }
227
228 pub fn version(&self) -> &str {
234 &self.version
235 }
236
237 #[doc(alias = "header lookup")]
260 pub fn header(&self, name: &str) -> Option<&str> {
261 self.headers
262 .get(&name.to_ascii_lowercase())
263 .map(String::as_str)
264 }
265
266 pub fn headers(&self) -> &HashMap<String, String> {
268 &self.headers
269 }
270
271 fn is_valid_method(method: &str) -> bool {
281 matches!(
282 method.to_ascii_uppercase().as_str(),
283 "GET"
284 | "POST"
285 | "PUT"
286 | "DELETE"
287 | "HEAD"
288 | "OPTIONS"
289 | "PATCH"
290 )
291 }
292
293 fn is_valid_version(version: &str) -> bool {
303 version.eq_ignore_ascii_case("HTTP/1.0")
304 || version.eq_ignore_ascii_case("HTTP/1.1")
305 }
306
307 fn read_headers<R: BufRead>(
308 reader: &mut R,
309 ) -> Result<HashMap<String, String>, ServerError> {
310 let mut headers = HashMap::with_capacity(16);
311 let mut total_bytes = 0_usize;
312
313 loop {
314 let mut line = String::new();
315 let bytes =
316 reader.read_line(&mut line).map_err(map_read_error)?;
317 if bytes == 0 {
318 break;
319 }
320 total_bytes = total_bytes.saturating_add(bytes);
321 if total_bytes > MAX_HEADER_BYTES {
322 return Err(ServerError::invalid_request(
323 "Header section too large",
324 ));
325 }
326
327 let trimmed = line.trim_end();
328 if trimmed.is_empty() {
329 break;
330 }
331 if trimmed.len() > MAX_HEADER_LINE_LENGTH {
332 return Err(ServerError::invalid_request(
333 "Header line too long",
334 ));
335 }
336 let (name, value) =
337 trimmed.split_once(':').ok_or_else(|| {
338 ServerError::invalid_request(
339 "Malformed header line",
340 )
341 })?;
342 if headers.len() >= MAX_HEADER_COUNT {
343 return Err(ServerError::invalid_request(
344 "Too many request headers",
345 ));
346 }
347 let _ = headers.insert(
348 name.trim().to_ascii_lowercase(),
349 value.trim().to_string(),
350 );
351 }
352
353 Ok(headers)
354 }
355}
356
357impl fmt::Display for Request {
358 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
359 write!(f, "{} {} {}", self.method, self.path, self.version)
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::net::TcpListener;

    /// Starts a one-shot peer on an ephemeral loopback port that writes
    /// `payload` to the first connection it accepts and then drops it, and
    /// returns the connecting side of that socket.
    fn connect_to_peer_sending(payload: &[u8]) -> TcpStream {
        let bytes = payload.to_vec();
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let _ = std::thread::spawn(move || {
            let (mut peer, _) = listener.accept().unwrap();
            peer.write_all(&bytes).unwrap();
        });

        TcpStream::connect(addr).unwrap()
    }

    /// Feeds `payload` through a loopback socket into `Request::from_stream`.
    fn parse_payload(payload: &[u8]) -> Result<Request, ServerError> {
        let stream = connect_to_peer_sending(payload);
        Request::from_stream(&stream)
    }

    /// Asserts that parsing `payload` fails with `ServerError::InvalidRequest`.
    fn assert_invalid_request(payload: &[u8]) {
        let result = parse_payload(payload);

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ServerError::InvalidRequest(_)
        ));
    }

    /// Builds a `GET` request line whose path is `path_len` slashes.
    fn slash_path_request(path_len: usize) -> String {
        format!("GET {} HTTP/1.1\r\n", "/".repeat(path_len))
    }

    #[test]
    fn test_valid_request() {
        let request =
            parse_payload(b"GET /index.html HTTP/1.1\r\n").unwrap();

        assert_eq!(request.method(), "GET");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
        assert!(request.headers().is_empty());
    }

    #[test]
    fn test_invalid_method() {
        assert_invalid_request(b"INVALID /index.html HTTP/1.1\r\n");
    }

    #[test]
    fn test_max_length_request() {
        // "GET " + path + " HTTP/1.1\r\n" adds 15 bytes around the path, so
        // this line is one byte under the limit.
        let payload = slash_path_request(MAX_REQUEST_LINE_LENGTH - 16);
        let result = parse_payload(payload.as_bytes());

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().path().len(),
            MAX_REQUEST_LINE_LENGTH - 16
        );
    }

    #[test]
    fn test_oversized_request() {
        // Three bytes longer than the maximal case above: two over the limit.
        let payload = slash_path_request(MAX_REQUEST_LINE_LENGTH - 13);
        let result = parse_payload(payload.as_bytes());

        assert!(
            result.is_err(),
            "Oversized request should be invalid. Request: {:?}",
            result
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Request line too long:"),
            "Unexpected error message: {}",
            msg
        );
    }

    #[test]
    fn test_invalid_path() {
        assert_invalid_request(b"GET index.html HTTP/1.1\r\n");
    }

    #[test]
    fn test_invalid_version() {
        assert_invalid_request(b"GET /index.html HTTP/2.0\r\n");
    }

    #[test]
    fn test_head_request() {
        let request =
            parse_payload(b"HEAD /index.html HTTP/1.1\r\n").unwrap();

        assert_eq!(request.method(), "HEAD");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_options_request() {
        let request = parse_payload(b"OPTIONS * HTTP/1.1\r\n").unwrap();

        assert_eq!(request.method(), "OPTIONS");
        assert_eq!(request.path(), "*");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_internal_error_mapping_helpers() {
        let from_timeout = map_timeout_error(io::Error::new(
            io::ErrorKind::TimedOut,
            "timeout",
        ));
        assert!(from_timeout
            .to_string()
            .contains("Failed to set read timeout"));

        let from_read = map_read_error(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "eof",
        ));
        assert!(from_read
            .to_string()
            .contains("Failed to read request line"));
    }

    #[test]
    fn test_parses_headers() {
        let request = parse_payload(
            b"GET /index.html HTTP/1.1\r\nHost: localhost\r\nRange: bytes=0-1\r\n\r\n",
        )
        .unwrap();

        assert_eq!(request.header("host"), Some("localhost"));
        assert_eq!(request.header("range"), Some("bytes=0-1"));
    }
}