1use crate::error::ServerError;
15use crate::request::Request;
16use crate::response::Response;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::fs;
20use std::io;
21use std::net::{IpAddr, TcpListener, TcpStream};
22use std::path::{Path, PathBuf};
23use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
24use std::sync::mpsc::{self, Receiver, Sender};
25use std::sync::{Arc, Mutex, Once, OnceLock};
26use std::thread;
27use std::time::{Duration, Instant, UNIX_EPOCH};
28
/// Shutdown signal that the process-wide Ctrl+C handler triggers.
///
/// `install_signal_handlers` overwrites the slot on every call, so the
/// signal passed by the most recent graceful-shutdown start is the one the
/// handler reaches.
static SHUTDOWN_SIGNAL_SLOT: OnceLock<
    Mutex<Option<Arc<ShutdownSignal>>>,
> = OnceLock::new();
/// Ensures `ctrlc::set_handler` is invoked at most once per process.
static SIGNAL_HANDLER_INSTALL: Once = Once::new();
/// Per-client-IP timestamps of recent requests, consulted by
/// `is_rate_limited` to enforce `rate_limit_per_minute`.
static RATE_LIMIT_STATE: OnceLock<
    Mutex<HashMap<IpAddr, Vec<Instant>>>,
> = OnceLock::new();
// Counters reported as plain-text lines by the `GET /metrics` endpoint.
/// Incremented by `record_metrics` for every response it records.
static METRIC_REQUESTS_TOTAL: AtomicUsize = AtomicUsize::new(0);
/// Incremented by `record_metrics` for recorded responses with a 4xx status.
static METRIC_RESPONSES_4XX: AtomicUsize = AtomicUsize::new(0);
/// Incremented by `record_metrics` for recorded responses with a status >= 500.
static METRIC_RESPONSES_5XX: AtomicUsize = AtomicUsize::new(0);
/// Incremented each time a request is rejected with 429 by the rate limiter.
static METRIC_RATE_LIMITED: AtomicUsize = AtomicUsize::new(0);
40
/// Configuration for the static-file HTTP server.
///
/// Construct with [`Server::new`] for the defaults, or with
/// [`Server::builder`] to set the optional fields. All optional fields are
/// `None` unless configured.
#[doc(alias = "http server")]
#[doc(alias = "static file server")]
#[derive(
    Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize,
)]
pub struct Server {
    /// Address handed to `TcpListener::bind` (e.g. `"127.0.0.1:8080"`).
    address: String,
    /// Directory that request paths are resolved against.
    document_root: PathBuf,
    /// `Some(true)` turns CORS headers on; `None` is treated as disabled.
    cors_enabled: Option<bool>,
    /// Allowed CORS origins; when unset, `Access-Control-Allow-Origin: *`
    /// is sent.
    cors_origins: Option<Vec<String>>,
    /// Extra headers added to every successfully generated response.
    custom_headers: Option<HashMap<String, String>>,
    /// Read/write socket timeout used by `handle_connection`
    /// (30 s when unset).
    request_timeout: Option<Duration>,
    /// Read/write socket timeout used for connections served in
    /// graceful-shutdown mode (30 s when unset).
    connection_timeout: Option<Duration>,
    /// Maximum requests per client IP within a 60-second window;
    /// `None` disables rate limiting.
    rate_limit_per_minute: Option<usize>,
    /// `max-age`, in seconds, for the `Cache-Control` header on served
    /// files; `None` sends no default cache header.
    static_cache_ttl_secs: Option<u64>,
}
78
/// Fluent builder for [`Server`].
///
/// `address` and `document_root` are required by [`ServerBuilder::build`];
/// every other field is copied through to the `Server` unchanged.
#[doc(alias = "builder")]
#[doc(alias = "configuration")]
#[derive(Clone, Debug, Default)]
pub struct ServerBuilder {
    /// Bind address; `build` fails with "Address is required" if unset.
    address: Option<String>,
    /// Served directory; `build` fails with "Document root is required"
    /// if unset.
    document_root: Option<PathBuf>,
    /// Explicit CORS on/off choice.
    cors_enabled: Option<bool>,
    /// Allowed CORS origins (setting them via the builder also enables CORS).
    cors_origins: Option<Vec<String>>,
    /// Extra response headers, keyed by header name.
    custom_headers: Option<HashMap<String, String>>,
    /// Socket timeout for the plain connection handler.
    request_timeout: Option<Duration>,
    /// Socket timeout for graceful-shutdown connections.
    connection_timeout: Option<Duration>,
    /// Per-IP request budget per minute (clamped to at least 1 by the setter).
    rate_limit_per_minute: Option<usize>,
    /// Cache-Control `max-age` in seconds.
    static_cache_ttl_secs: Option<u64>,
}
120
121impl ServerBuilder {
122 #[doc(alias = "new builder")]
138 pub fn new() -> Self {
139 Self::default()
140 }
141
142 #[doc(alias = "bind address")]
161 pub fn address(mut self, address: &str) -> Self {
162 self.address = Some(address.to_string());
163 self
164 }
165
166 #[doc(alias = "document root")]
185 pub fn document_root(mut self, path: &str) -> Self {
186 self.document_root = Some(PathBuf::from(path));
187 self
188 }
189
190 pub fn enable_cors(mut self) -> Self {
192 self.cors_enabled = Some(true);
193 self
194 }
195
196 pub fn disable_cors(mut self) -> Self {
198 self.cors_enabled = Some(false);
199 self
200 }
201
202 pub fn cors_origins(mut self, origins: Vec<String>) -> Self {
204 self.cors_origins = Some(origins);
205 self.cors_enabled = Some(true); self
207 }
208
209 pub fn custom_header(mut self, name: &str, value: &str) -> Self {
211 let mut headers = self.custom_headers.unwrap_or_default();
212 let _ = headers.insert(name.to_string(), value.to_string());
213 self.custom_headers = Some(headers);
214 self
215 }
216
217 pub fn custom_headers(
219 mut self,
220 headers: HashMap<String, String>,
221 ) -> Self {
222 self.custom_headers = Some(headers);
223 self
224 }
225
226 pub fn request_timeout(mut self, timeout: Duration) -> Self {
228 self.request_timeout = Some(timeout);
229 self
230 }
231
232 pub fn connection_timeout(mut self, timeout: Duration) -> Self {
234 self.connection_timeout = Some(timeout);
235 self
236 }
237
238 pub fn rate_limit_per_minute(mut self, requests: usize) -> Self {
240 self.rate_limit_per_minute = Some(requests.max(1));
241 self
242 }
243
244 pub fn static_cache_ttl_secs(mut self, ttl: u64) -> Self {
246 self.static_cache_ttl_secs = Some(ttl);
247 self
248 }
249
250 #[doc(alias = "finalize")]
274 pub fn build(self) -> Result<Server, &'static str> {
275 let address = self.address.ok_or("Address is required")?;
276 let document_root =
277 self.document_root.ok_or("Document root is required")?;
278
279 Ok(Server {
280 address,
281 document_root,
282 cors_enabled: self.cors_enabled,
283 cors_origins: self.cors_origins,
284 custom_headers: self.custom_headers,
285 request_timeout: self.request_timeout,
286 connection_timeout: self.connection_timeout,
287 rate_limit_per_minute: self.rate_limit_per_minute,
288 static_cache_ttl_secs: self.static_cache_ttl_secs,
289 })
290 }
291}
292
/// Cloneable handle for coordinating a graceful shutdown.
///
/// It holds a flag that requests shutdown and a shared count of in-flight
/// connections to drain. Clones share both through `Arc`, so a shutdown
/// requested through one clone is visible through all of them.
#[derive(Debug, Clone)]
pub struct ShutdownSignal {
    /// Becomes `true` once [`ShutdownSignal::shutdown`] has been called.
    pub should_shutdown: Arc<AtomicBool>,
    /// Number of connections currently being handled.
    pub active_connections: Arc<AtomicUsize>,
    /// Upper bound on how long [`ShutdownSignal::wait_for_shutdown`] waits.
    pub shutdown_timeout: Duration,
}

impl Default for ShutdownSignal {
    /// A signal with a 30-second drain timeout.
    fn default() -> Self {
        Self::new(Duration::from_secs(30))
    }
}

impl ShutdownSignal {
    /// Creates an untriggered signal with no active connections.
    pub fn new(shutdown_timeout: Duration) -> Self {
        Self {
            should_shutdown: Arc::new(AtomicBool::new(false)),
            active_connections: Arc::new(AtomicUsize::new(0)),
            shutdown_timeout,
        }
    }

    /// Requests shutdown and announces it on stdout.
    pub fn shutdown(&self) {
        self.should_shutdown.store(true, Ordering::SeqCst);
        println!(
            "🛑 Shutdown signal received. Waiting for active connections to finish..."
        );
    }

    /// Returns whether [`ShutdownSignal::shutdown`] has been called.
    pub fn is_shutdown_requested(&self) -> bool {
        self.should_shutdown.load(Ordering::SeqCst)
    }

    /// Records that a connection has begun.
    pub fn connection_started(&self) {
        let _ = self.active_connections.fetch_add(1, Ordering::SeqCst);
    }

    /// Records that a connection has ended.
    ///
    /// The count saturates at zero. An unmatched call would otherwise wrap
    /// the counter to `usize::MAX`, and `wait_for_shutdown` would then
    /// always wait for the full timeout.
    pub fn connection_finished(&self) {
        let mut current = self.active_connections.load(Ordering::SeqCst);
        while current > 0 {
            match self.active_connections.compare_exchange_weak(
                current,
                current - 1,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => return,
                Err(actual) => current = actual,
            }
        }
    }

    /// Returns the number of connections currently in flight.
    pub fn active_connection_count(&self) -> usize {
        self.active_connections.load(Ordering::SeqCst)
    }

    /// Blocks until no connections remain or `shutdown_timeout` elapses.
    ///
    /// The count is polled every 50 ms. A progress line is printed at most
    /// once per second rather than on every poll.
    ///
    /// Returns `true` if every connection finished, and `false` if the
    /// timeout was reached with connections still active.
    pub fn wait_for_shutdown(&self) -> bool {
        const POLL_INTERVAL: Duration = Duration::from_millis(50);
        const PROGRESS_INTERVAL: Duration = Duration::from_secs(1);

        let start_time = Instant::now();
        let mut last_report: Option<Instant> = None;

        while self.active_connection_count() > 0
            && start_time.elapsed() < self.shutdown_timeout
        {
            let remaining = self
                .shutdown_timeout
                .saturating_sub(start_time.elapsed());
            if last_report.is_none_or(|at| at.elapsed() >= PROGRESS_INTERVAL)
            {
                println!(
                    "⏳ Waiting for {} active connection(s) to finish... ({:.1}s remaining)",
                    self.active_connection_count(),
                    remaining.as_secs_f32()
                );
                last_report = Some(Instant::now());
            }

            // Never sleep past the deadline.
            thread::sleep(remaining.min(POLL_INTERVAL));
        }

        let remaining_connections = self.active_connection_count();
        if remaining_connections > 0 {
            println!(
                "⚠️ Shutdown timeout reached. {} connection(s) will be forcibly terminated.",
                remaining_connections
            );
            false
        } else {
            println!("✅ All connections closed gracefully.");
            true
        }
    }
}
381
/// Fixed-size pool of worker threads that run boxed jobs received over a
/// shared channel.
pub struct ThreadPool {
    workers: Vec<Worker>,
    sender: Sender<Job>,
}

impl std::fmt::Debug for ThreadPool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ThreadPool")
            .field("workers", &self.workers)
            .field("sender", &"<Sender<Job>>")
            .finish()
    }
}

/// A unit of work executed by a pool worker.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// One pool thread; `thread` is taken (set to `None`) when it is joined.
#[derive(Debug)]
struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl ThreadPool {
    /// Spawns `size` worker threads sharing one job queue.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);

        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));

        let workers = (0..size)
            .map(|id| Worker::new(id, Arc::clone(&receiver)))
            .collect();

        ThreadPool { workers, sender }
    }

    /// Queues `f` to run on the next idle worker.
    ///
    /// # Panics
    ///
    /// Panics if every worker has exited. Workers survive panicking jobs
    /// and only stop when the pool is dropped, so this cannot happen while
    /// `&self` is alive.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.sender
            .send(Box::new(f))
            .expect("thread pool workers exited while the pool is alive");
    }
}

impl Drop for ThreadPool {
    /// Disconnects the job channel so that workers exit after draining the
    /// queue, then joins each worker.
    fn drop(&mut self) {
        // Swap in a throwaway sender so the real one is dropped and the
        // workers' `recv` returns `Err`.
        let (replacement_sender, _replacement_receiver) = mpsc::channel();
        drop(std::mem::replace(&mut self.sender, replacement_sender));

        for worker in &mut self.workers {
            println!("Shutting down worker {}", worker.id);

            if let Some(thread) = worker.thread.take() {
                // Report, rather than re-panic, so that dropping the pool
                // never panics (a panic inside drop during unwinding aborts).
                if thread.join().is_err() {
                    eprintln!("Worker {} panicked during shutdown.", worker.id);
                }
            }
        }
    }
}

impl Worker {
    /// Spawns a thread that runs jobs from `receiver` until the channel
    /// disconnects.
    fn new(id: usize, receiver: Arc<Mutex<Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            // The guard is a temporary dropped at the end of this statement,
            // so the lock is released before the job runs. A poisoned lock
            // still guards a usable receiver.
            let message = receiver
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .recv();

            match message {
                Ok(job) => {
                    // Contain job panics so the pool keeps its full size.
                    if std::panic::catch_unwind(std::panic::AssertUnwindSafe(
                        job,
                    ))
                    .is_err()
                    {
                        eprintln!("Worker {} recovered from a panicking job.", id);
                    }
                }
                Err(_) => {
                    println!(
                        "Worker {} disconnected; shutting down.",
                        id
                    );
                    break;
                }
            }
        });

        Worker {
            id,
            thread: Some(thread),
        }
    }
}
490
/// Caps how many connections may be handled at once.
///
/// Each successful [`ConnectionPool::acquire`] occupies one slot until the
/// returned [`ConnectionGuard`] is dropped.
#[derive(Debug)]
pub struct ConnectionPool {
    max_connections: usize,
    active_connections: Arc<AtomicUsize>,
}

impl ConnectionPool {
    /// Creates a pool that allows up to `max_connections` concurrent slots.
    pub fn new(max_connections: usize) -> Self {
        let active_connections = Arc::new(AtomicUsize::new(0));
        Self {
            max_connections,
            active_connections,
        }
    }

    /// Reserves a slot without blocking.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::WouldBlock` ("Connection pool exhausted")
    /// when every slot is taken.
    pub fn acquire(&self) -> Result<ConnectionGuard, io::Error> {
        let counter = &self.active_connections;
        let mut seen = counter.load(Ordering::SeqCst);
        loop {
            if seen >= self.max_connections {
                return Err(io::Error::new(
                    io::ErrorKind::WouldBlock,
                    "Connection pool exhausted",
                ));
            }
            // Claim the slot only if no other thread changed the count in
            // the meantime; otherwise retry with the fresh value.
            match counter.compare_exchange_weak(
                seen,
                seen + 1,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => {
                    return Ok(ConnectionGuard {
                        pool: Arc::clone(counter),
                    })
                }
                Err(actual) => seen = actual,
            }
        }
    }

    /// Returns the number of slots currently held.
    pub fn active_count(&self) -> usize {
        self.active_connections.load(Ordering::SeqCst)
    }
}

/// Holds one [`ConnectionPool`] slot and releases it when dropped.
#[derive(Debug)]
pub struct ConnectionGuard {
    pool: Arc<AtomicUsize>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let _ = self.pool.fetch_sub(1, Ordering::SeqCst);
    }
}
550
551impl Server {
552 #[doc(alias = "constructor")]
570 pub fn new(address: &str, document_root: &str) -> Self {
571 Server {
572 address: address.to_string(),
573 document_root: PathBuf::from(document_root),
574 cors_enabled: None,
575 cors_origins: None,
576 custom_headers: None,
577 request_timeout: None,
578 connection_timeout: None,
579 rate_limit_per_minute: None,
580 static_cache_ttl_secs: None,
581 }
582 }
583
584 pub fn builder() -> ServerBuilder {
603 ServerBuilder::new()
604 }
605
606 #[doc(alias = "listen")]
628 #[doc(alias = "serve")]
629 pub fn start(&self) -> io::Result<()> {
630 let listener = TcpListener::bind(&self.address)?;
631 println!("❯ Server is now running at http://{}", self.address);
632 println!(" Document root: {}", self.document_root.display());
633 println!(" Press Ctrl+C to stop the server.");
634
635 Self::run_basic_accept_loop(listener.incoming(), self.clone());
636
637 Ok(())
638 }
639
640 #[doc(alias = "graceful shutdown")]
663 pub fn start_with_graceful_shutdown(
664 &self,
665 shutdown_timeout: Duration,
666 ) -> io::Result<()> {
667 let shutdown = Arc::new(ShutdownSignal::new(shutdown_timeout));
668 self.start_with_shutdown_signal(shutdown)
669 }
670
671 #[doc(alias = "shutdown signal")]
693 pub fn start_with_shutdown_signal(
694 &self,
695 shutdown: Arc<ShutdownSignal>,
696 ) -> io::Result<()> {
697 self.start_with_shutdown_signal_and_ready(shutdown, |_| {})
698 }
699
700 pub fn start_with_shutdown_signal_and_ready<F>(
714 &self,
715 shutdown: Arc<ShutdownSignal>,
716 on_ready: F,
717 ) -> io::Result<()>
718 where
719 F: FnOnce(String),
720 {
721 Self::install_signal_handlers(shutdown.clone());
723
724 let listener = TcpListener::bind(&self.address)?;
725 let bound_address = listener.local_addr()?.to_string();
726 on_ready(bound_address.clone());
727 println!("❯ Server is now running at http://{}", bound_address);
728 println!(" Document root: {}", self.document_root.display());
729 println!(" Press Ctrl+C to stop the server gracefully.");
730
731 listener.set_nonblocking(true)?;
733
734 loop {
735 if shutdown.is_shutdown_requested() {
737 println!(
738 "🛑 Shutdown requested. Stopping new connections..."
739 );
740 break;
741 }
742
743 match listener.accept() {
744 Ok((stream, _addr)) => Self::run_tracked_accept(
745 stream,
746 self.clone(),
747 shutdown.clone(),
748 ),
749 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
750 thread::sleep(Duration::from_millis(100));
752 }
753 Err(e) => Self::log_listener_error(e),
754 }
755 }
756
757 let graceful = shutdown.wait_for_shutdown();
759
760 if graceful {
761 println!("✅ Server shut down gracefully.");
762 } else {
763 println!(
764 "⚠️ Server shut down with active connections remaining."
765 );
766 }
767
768 Ok(())
769 }
770
771 fn install_signal_handlers(shutdown: Arc<ShutdownSignal>) {
777 let slot =
778 SHUTDOWN_SIGNAL_SLOT.get_or_init(|| Mutex::new(None));
779
780 if let Ok(mut guard) = slot.lock() {
782 *guard = Some(shutdown);
783 }
784
785 SIGNAL_HANDLER_INSTALL.call_once(|| {
787 let _ = ctrlc::set_handler(Self::handle_shutdown_signal);
788 });
789 }
790
791 fn handle_shutdown_signal() {
792 if let Some(slot) = SHUTDOWN_SIGNAL_SLOT.get() {
793 Self::trigger_shutdown_from_slot(slot);
794 }
795 }
796
797 fn trigger_shutdown_from_slot(
798 slot: &Mutex<Option<Arc<ShutdownSignal>>>,
799 ) {
800 if let Ok(guard) = slot.lock()
801 && let Some(shutdown_signal) = guard.as_ref()
802 {
803 shutdown_signal.shutdown();
804 }
805 }
806
807 pub fn start_with_thread_pool(
820 &self,
821 thread_pool_size: usize,
822 ) -> io::Result<()> {
823 let thread_pool = ThreadPool::new(thread_pool_size);
824 let listener = TcpListener::bind(&self.address)?;
825
826 println!("❯ Server is now running at http://{}", self.address);
827 println!(" Document root: {}", self.document_root.display());
828 println!(" Thread pool size: {} workers", thread_pool_size);
829 println!(" Press Ctrl+C to stop the server.");
830
831 Self::run_thread_pool_accept_loop(
832 listener.incoming(),
833 self.clone(),
834 &thread_pool,
835 );
836
837 Ok(())
838 }
839
840 pub fn start_with_pooling(
856 &self,
857 thread_pool_size: usize,
858 max_connections: usize,
859 ) -> io::Result<()> {
860 let thread_pool = ThreadPool::new(thread_pool_size);
861 let connection_pool =
862 Arc::new(ConnectionPool::new(max_connections));
863 let listener = TcpListener::bind(&self.address)?;
864
865 println!("❯ Server is now running at http://{}", self.address);
866 println!(" Document root: {}", self.document_root.display());
867 println!(" Thread pool size: {} workers", thread_pool_size);
868 println!(" Max concurrent connections: {}", max_connections);
869 println!(" Press Ctrl+C to stop the server.");
870
871 Self::run_pooling_accept_loop(
872 listener.incoming(),
873 self.clone(),
874 &thread_pool,
875 connection_pool,
876 );
877
878 Ok(())
879 }
880
881 fn log_connection_result(result: Result<(), ServerError>) {
882 if let Err(error) = result {
883 eprintln!("Error handling connection: {}", error);
884 }
885 }
886
887 fn log_listener_error(error: io::Error) {
888 eprintln!("Connection error: {}", error);
889 }
890
891 fn run_tracked_accept(
892 stream: TcpStream,
893 server: Server,
894 shutdown: Arc<ShutdownSignal>,
895 ) {
896 shutdown.connection_started();
897 let _ = thread::spawn(move || {
898 let result =
899 handle_connection_tracked(stream, &server, &shutdown);
900 shutdown.connection_finished();
901 Self::log_connection_result(result);
902 });
903 }
904
905 fn run_basic_accept_loop<I>(incoming: I, server: Server)
906 where
907 I: IntoIterator<Item = io::Result<TcpStream>>,
908 {
909 for stream in incoming {
910 match stream {
911 Ok(stream) => {
912 let server = server.clone();
913 let _ = thread::spawn(move || {
914 Self::log_connection_result(handle_connection(
915 stream, &server,
916 ));
917 });
918 }
919 Err(error) => Self::log_listener_error(error),
920 }
921 }
922 }
923
924 fn run_thread_pool_accept_loop<I>(
925 incoming: I,
926 server: Server,
927 thread_pool: &ThreadPool,
928 ) where
929 I: IntoIterator<Item = io::Result<TcpStream>>,
930 {
931 for stream in incoming {
932 match stream {
933 Ok(stream) => {
934 let server = server.clone();
935 thread_pool.execute(move || {
936 Self::log_connection_result(handle_connection(
937 stream, &server,
938 ));
939 });
940 }
941 Err(error) => Self::log_listener_error(error),
942 }
943 }
944 }
945
946 fn run_pooling_accept_loop<I>(
947 incoming: I,
948 server: Server,
949 thread_pool: &ThreadPool,
950 connection_pool: Arc<ConnectionPool>,
951 ) where
952 I: IntoIterator<Item = io::Result<TcpStream>>,
953 {
954 for stream in incoming {
955 match stream {
956 Ok(stream) => {
957 let server = server.clone();
958 let pool_clone = Arc::clone(&connection_pool);
959 thread_pool.execute(move || match pool_clone.acquire() {
960 Ok(_guard) => Self::log_connection_result(
961 handle_connection(stream, &server),
962 ),
963 Err(_) => {
964 if let Err(error) =
965 send_service_unavailable(stream)
966 {
967 eprintln!(
968 "Error sending service unavailable: {}",
969 error
970 );
971 }
972 }
973 });
974 }
975 Err(error) => Self::log_listener_error(error),
976 }
977 }
978 }
979
980 pub fn cors_enabled(&self) -> Option<bool> {
984 self.cors_enabled
985 }
986
987 pub fn cors_origins(&self) -> &Option<Vec<String>> {
989 &self.cors_origins
990 }
991
992 pub fn custom_headers(&self) -> &Option<HashMap<String, String>> {
994 &self.custom_headers
995 }
996
997 pub fn request_timeout(&self) -> Option<Duration> {
999 self.request_timeout
1000 }
1001
1002 pub fn connection_timeout(&self) -> Option<Duration> {
1004 self.connection_timeout
1005 }
1006
1007 pub fn address(&self) -> &str {
1009 &self.address
1010 }
1011
1012 pub fn document_root(&self) -> &PathBuf {
1014 &self.document_root
1015 }
1016}
1017
1018fn send_service_unavailable(mut stream: TcpStream) -> io::Result<()> {
1028 let mut response = Response::new(
1029 503,
1030 "SERVICE UNAVAILABLE",
1031 b"Service temporarily unavailable. Please try again later."
1032 .to_vec(),
1033 );
1034
1035 response.add_header("Content-Type", "text/plain");
1036 response.add_header("Retry-After", "1"); response.add_header("Connection", "close");
1038
1039 response.send(&mut stream).map_err(|e| {
1040 use std::io::Error;
1041 Error::other(format!("Failed to send response: {}", e))
1042 })?;
1043 Ok(())
1044}
1045
1046pub(crate) fn handle_connection(
1057 mut stream: TcpStream,
1058 server: &Server,
1059) -> Result<(), ServerError> {
1060 let timeout =
1061 server.request_timeout.unwrap_or(Duration::from_secs(30));
1062 stream.set_read_timeout(Some(timeout))?;
1063 stream.set_write_timeout(Some(timeout))?;
1064
1065 let peer_ip = stream.peer_addr().ok().map(|addr| addr.ip());
1066 let response = build_response_for_stream(server, &stream, peer_ip);
1067 response.send(&mut stream)?;
1068 Ok(())
1069}
1070
1071fn handle_connection_tracked(
1086 mut stream: TcpStream,
1087 server: &Server,
1088 _shutdown: &ShutdownSignal,
1089) -> Result<(), ServerError> {
1090 stream.set_nonblocking(false)?;
1092
1093 let timeout =
1095 server.connection_timeout.unwrap_or(Duration::from_secs(30));
1096 stream.set_read_timeout(Some(timeout))?;
1097 stream.set_write_timeout(Some(timeout))?;
1098
1099 let peer_ip = stream.peer_addr().ok().map(|addr| addr.ip());
1100 let response = build_response_for_stream(server, &stream, peer_ip);
1101 response.send(&mut stream)?;
1102 Ok(())
1103}
1104
1105fn build_response_for_stream(
1106 server: &Server,
1107 stream: &TcpStream,
1108 peer_ip: Option<IpAddr>,
1109) -> Response {
1110 match Request::from_stream(stream) {
1111 Ok(request) => {
1112 if request.path() == "/metrics" && request.method() == "GET"
1113 {
1114 return generate_metrics_response();
1115 }
1116 if let Some(ip) = peer_ip
1117 && is_rate_limited(server, ip)
1118 {
1119 let _ =
1120 METRIC_RATE_LIMITED.fetch_add(1, Ordering::Relaxed);
1121 return generate_too_many_requests_response();
1122 }
1123 build_response_for_request_with_metrics(server, &request)
1124 }
1125 Err(error) => {
1126 response_from_error(&error, &server.document_root)
1127 }
1128 }
1129}
1130
1131pub(crate) fn build_response_for_request_with_metrics(
1136 server: &Server,
1137 request: &Request,
1138) -> Response {
1139 let response = build_response_for_request(server, request);
1140 record_metrics(&response);
1141 response
1142}
1143
1144pub(crate) fn build_response_for_request(
1146 server: &Server,
1147 request: &Request,
1148) -> Response {
1149 let generated = match request.method() {
1150 "GET" => generate_response_with_cache(
1151 request,
1152 &server.document_root,
1153 server.static_cache_ttl_secs,
1154 ),
1155 "HEAD" => {
1156 generate_head_response(request, &server.document_root)
1157 }
1158 "OPTIONS" => generate_options_response(request),
1159 _ => Ok(generate_method_not_allowed_response()),
1160 };
1161 match generated {
1162 Ok(response) => {
1163 apply_response_policies(response, server, request)
1164 }
1165 Err(error) => {
1166 response_from_error(&error, &server.document_root)
1167 }
1168 }
1169}
1170
1171fn record_metrics(response: &Response) {
1172 let _ = METRIC_REQUESTS_TOTAL.fetch_add(1, Ordering::Relaxed);
1173 if (400..500).contains(&response.status_code) {
1174 let _ = METRIC_RESPONSES_4XX.fetch_add(1, Ordering::Relaxed);
1175 } else if response.status_code >= 500 {
1176 let _ = METRIC_RESPONSES_5XX.fetch_add(1, Ordering::Relaxed);
1177 }
1178}
1179
1180fn generate_metrics_response() -> Response {
1181 let body = format!(
1182 "http_handle_requests_total {}\nhttp_handle_responses_4xx_total {}\nhttp_handle_responses_5xx_total {}\nhttp_handle_rate_limited_total {}\n",
1183 METRIC_REQUESTS_TOTAL.load(Ordering::Relaxed),
1184 METRIC_RESPONSES_4XX.load(Ordering::Relaxed),
1185 METRIC_RESPONSES_5XX.load(Ordering::Relaxed),
1186 METRIC_RATE_LIMITED.load(Ordering::Relaxed),
1187 );
1188 let mut response = Response::new(200, "OK", body.into_bytes());
1189 response.add_header("Content-Type", "text/plain; version=0.0.3");
1190 response
1191}
1192
1193fn generate_too_many_requests_response() -> Response {
1194 let mut response = Response::new(
1195 429,
1196 "TOO MANY REQUESTS",
1197 b"Rate limit exceeded".to_vec(),
1198 );
1199 response.add_header("Content-Type", "text/plain");
1200 response.add_header("Retry-After", "60");
1201 response
1202}
1203
1204fn is_rate_limited(server: &Server, ip: IpAddr) -> bool {
1205 let Some(limit) = server.rate_limit_per_minute else {
1206 return false;
1207 };
1208 let now = Instant::now();
1209 let state =
1210 RATE_LIMIT_STATE.get_or_init(|| Mutex::new(HashMap::new()));
1211 let mut guard = match state.lock() {
1212 Ok(guard) => guard,
1213 Err(poisoned) => poisoned.into_inner(),
1214 };
1215 let hits = guard.entry(ip).or_default();
1216 hits.retain(|timestamp| {
1217 now.duration_since(*timestamp) <= Duration::from_secs(60)
1218 });
1219 if hits.len() >= limit {
1220 return true;
1221 }
1222 hits.push(now);
1223 false
1224}
1225
1226fn generate_response(
1237 request: &Request,
1238 document_root: &Path,
1239) -> Result<Response, ServerError> {
1240 generate_response_with_cache(request, document_root, None)
1241}
1242
1243fn generate_response_with_cache(
1244 request: &Request,
1245 document_root: &Path,
1246 cache_ttl_secs: Option<u64>,
1247) -> Result<Response, ServerError> {
1248 let canonical_root = fs::canonicalize(document_root)
1249 .unwrap_or_else(|_| document_root.to_path_buf());
1250 let mut path = PathBuf::from(document_root);
1251 let request_path = request.path().trim_start_matches('/');
1252
1253 if request_path.is_empty() {
1254 path.push("index.html");
1256 } else {
1257 for component in request_path.split('/') {
1258 if component == ".." {
1259 let _ = path.pop();
1260 } else {
1261 path.push(component);
1262 }
1263 }
1264 }
1265
1266 let within_root = fs::canonicalize(&path)
1267 .map(|candidate| candidate.starts_with(&canonical_root))
1268 .unwrap_or_else(|_| path.starts_with(document_root));
1269 if !within_root {
1270 return Err(ServerError::forbidden("Access denied"));
1271 }
1272
1273 if path.is_file() {
1274 serve_file_response(request, &path, cache_ttl_secs)
1275 } else if path.is_dir() {
1276 path.push("index.html");
1278 if path.is_file() {
1279 serve_file_response(request, &path, cache_ttl_secs)
1280 } else {
1281 generate_404_response(document_root)
1282 }
1283 } else {
1284 generate_404_response(document_root)
1285 }
1286}
1287
1288fn serve_file_response(
1289 request: &Request,
1290 path: &Path,
1291 cache_ttl_secs: Option<u64>,
1292) -> Result<Response, ServerError> {
1293 let mut serving_path = path.to_path_buf();
1294 let mut content_encoding: Option<&'static str> = None;
1295 if let Some(encoding) = request.header("accept-encoding") {
1296 if encoding.contains("br") {
1297 let candidate =
1298 PathBuf::from(format!("{}.br", path.display()));
1299 if candidate.is_file() {
1300 serving_path = candidate;
1301 content_encoding = Some("br");
1302 }
1303 }
1304 if content_encoding.is_none()
1305 && (encoding.contains("zstd") || encoding.contains("zst"))
1306 {
1307 let candidate =
1308 PathBuf::from(format!("{}.zst", path.display()));
1309 if candidate.is_file() {
1310 serving_path = candidate;
1311 content_encoding = Some("zstd");
1312 }
1313 }
1314 if content_encoding.is_none() && encoding.contains("gzip") {
1315 let candidate =
1316 PathBuf::from(format!("{}.gz", path.display()));
1317 if candidate.is_file() {
1318 serving_path = candidate;
1319 content_encoding = Some("gzip");
1320 }
1321 }
1322 }
1323
1324 let contents = fs::read(&serving_path)?;
1325 let metadata = fs::metadata(path)?;
1326 let etag = compute_etag(&metadata);
1327 if request
1328 .header("if-none-match")
1329 .is_some_and(|candidate| candidate == etag)
1330 {
1331 let mut response =
1332 Response::new(304, "NOT MODIFIED", Vec::new());
1333 response.add_header("ETag", &etag);
1334 return Ok(response);
1335 }
1336
1337 let content_type = get_content_type(path);
1338 let mut response = if let Some((start, end)) =
1339 parse_range_header(request.header("range"), contents.len())
1340 {
1341 let body = contents[start..=end].to_vec();
1342 let mut partial = Response::new(206, "PARTIAL CONTENT", body);
1343 partial.add_header(
1344 "Content-Range",
1345 &format!("bytes {}-{}/{}", start, end, contents.len()),
1346 );
1347 partial
1348 } else {
1349 Response::new(200, "OK", contents)
1350 };
1351
1352 response.add_header("Content-Type", content_type);
1353 response.add_header("ETag", &etag);
1354 response.add_header("Accept-Ranges", "bytes");
1355 if let Some(encoding) = content_encoding {
1356 response.add_header("Content-Encoding", encoding);
1357 response.add_header("Vary", "Accept-Encoding");
1358 }
1359 if let Some(ttl) = cache_ttl_secs {
1360 response.add_header(
1361 "Cache-Control",
1362 &format!("public, max-age={ttl}"),
1363 );
1364 }
1365 Ok(response)
1366}
1367
/// Builds a weak ETag `W/"<size-hex>-<mtime-secs-hex>"` from file metadata.
///
/// An unavailable modification time, or one before the Unix epoch, counts
/// as 0.
fn compute_etag(metadata: &fs::Metadata) -> String {
    let mtime_secs: u64 = match metadata.modified() {
        Ok(time) => time
            .duration_since(UNIX_EPOCH)
            .map(|since_epoch| since_epoch.as_secs())
            .unwrap_or(0),
        Err(_) => 0,
    };
    let size = metadata.len();
    format!("W/\"{size:x}-{mtime_secs:x}\"")
}
1376
/// Parses a single-range `Range` header into inclusive byte offsets for a
/// representation of `total_len` bytes.
///
/// Accepted forms are `bytes=start-end`, `bytes=start-` and
/// `bytes=-suffix`. Per RFC 9110 §14.1.2, a last-byte position at or beyond
/// the end, or a suffix longer than the representation, is clamped to the
/// available bytes rather than rejected.
///
/// Returns `None` when the header is absent, malformed, uses a unit other
/// than `bytes`, or lists multiple ranges. It also returns `None` when the
/// range is unsatisfiable: a start past the end, a zero-length suffix, or
/// an empty representation. Callers then serve the full body.
fn parse_range_header(
    header: Option<&str>,
    total_len: usize,
) -> Option<(usize, usize)> {
    let header = header?;
    // Offset of the final byte; an empty representation has no satisfiable
    // range.
    let last_index = total_len.checked_sub(1)?;
    let spec = header.strip_prefix("bytes=")?;
    let (start_str, end_str) = spec.split_once('-')?;
    let (start_str, end_str) = (start_str.trim(), end_str.trim());

    if start_str.is_empty() {
        // Suffix range: the final `suffix_len` bytes. `bytes=-` fails here
        // because the empty string does not parse.
        let suffix_len = end_str.parse::<usize>().ok()?;
        if suffix_len == 0 {
            return None;
        }
        return Some((total_len - suffix_len.min(total_len), last_index));
    }

    let start = start_str.parse::<usize>().ok()?;
    if start > last_index {
        return None;
    }
    let end = if end_str.is_empty() {
        last_index
    } else {
        // A malformed value (including a second range after a comma) fails
        // to parse and rejects the header.
        let requested_end = end_str.parse::<usize>().ok()?;
        if requested_end < start {
            return None;
        }
        requested_end.min(last_index)
    };
    Some((start, end))
}
1405
1406fn generate_404_response(
1416 document_root: &Path,
1417) -> Result<Response, ServerError> {
1418 let not_found_path = document_root.join("404/index.html");
1419 let contents = if not_found_path.is_file() {
1420 fs::read(not_found_path)?
1421 } else {
1422 b"404 Not Found".to_vec()
1423 };
1424 let mut response = Response::new(404, "NOT FOUND", contents);
1425 response.add_header("Content-Type", "text/html");
1426 Ok(response)
1427}
1428
1429fn generate_head_response(
1444 request: &Request,
1445 document_root: &Path,
1446) -> Result<Response, ServerError> {
1447 let full_response = generate_response(request, document_root)?;
1449
1450 let mut head_response = Response::new(
1452 full_response.status_code,
1453 &full_response.status_text,
1454 Vec::new(), );
1456
1457 for (name, value) in &full_response.headers {
1459 head_response.add_header(name, value);
1460 }
1461
1462 let content_length = full_response.body.len().to_string();
1464 head_response.add_header("Content-Length", &content_length);
1465
1466 Ok(head_response)
1467}
1468
1469fn generate_options_response(
1482 _request: &Request,
1483) -> Result<Response, ServerError> {
1484 let mut response = Response::new(200, "OK", Vec::new());
1485 response.add_header("Allow", "GET, HEAD, OPTIONS");
1486 response.add_header("Content-Length", "0");
1487 Ok(response)
1488}
1489
1490fn generate_method_not_allowed_response() -> Response {
1499 let mut response = Response::new(
1500 405,
1501 "METHOD NOT ALLOWED",
1502 b"Method Not Allowed".to_vec(),
1503 );
1504 response.add_header("Allow", "GET, HEAD, OPTIONS");
1505 response.add_header("Content-Type", "text/plain");
1506 response.add_header("Content-Length", "18");
1507 response
1508}
1509
1510fn response_from_error(
1511 error: &ServerError,
1512 document_root: &Path,
1513) -> Response {
1514 match error {
1515 ServerError::InvalidRequest(message) => {
1516 let mut response = Response::new(
1517 400,
1518 "BAD REQUEST",
1519 message.as_bytes().to_vec(),
1520 );
1521 response.add_header("Content-Type", "text/plain");
1522 response
1523 }
1524 ServerError::Forbidden(message) => {
1525 let mut response = Response::new(
1526 403,
1527 "FORBIDDEN",
1528 message.as_bytes().to_vec(),
1529 );
1530 response.add_header("Content-Type", "text/plain");
1531 response
1532 }
1533 ServerError::NotFound(_) => {
1534 generate_404_response(document_root).unwrap_or_else(|_| {
1535 let mut response = Response::new(
1536 404,
1537 "NOT FOUND",
1538 b"404 Not Found".to_vec(),
1539 );
1540 response.add_header("Content-Type", "text/plain");
1541 response
1542 })
1543 }
1544 ServerError::Io(_)
1545 | ServerError::Custom(_)
1546 | ServerError::TaskFailed(_) => {
1547 let mut response = Response::new(
1548 500,
1549 "INTERNAL SERVER ERROR",
1550 b"Internal Server Error".to_vec(),
1551 );
1552 response.add_header("Content-Type", "text/plain");
1553 response
1554 }
1555 }
1556}
1557
1558fn apply_response_policies(
1559 mut response: Response,
1560 server: &Server,
1561 request: &Request,
1562) -> Response {
1563 if let Some(headers) = server.custom_headers.as_ref() {
1564 for (name, value) in headers {
1565 response.add_header(name, value);
1566 }
1567 }
1568
1569 if server.cors_enabled.unwrap_or(false) {
1570 let allow_origin = server
1571 .cors_origins
1572 .as_ref()
1573 .and_then(|origins| origins.first())
1574 .map(String::as_str)
1575 .unwrap_or("*");
1576 response
1577 .add_header("Access-Control-Allow-Origin", allow_origin);
1578 response.add_header(
1579 "Access-Control-Allow-Methods",
1580 "GET, HEAD, OPTIONS",
1581 );
1582 response.add_header("Access-Control-Allow-Headers", "*");
1583
1584 if request.method().eq_ignore_ascii_case("OPTIONS") {
1585 response.add_header("Access-Control-Max-Age", "600");
1586 }
1587 }
1588
1589 if let Some(ttl) = server.static_cache_ttl_secs {
1590 let has_cache_control =
1591 response.headers.iter().any(|(name, _)| {
1592 name.eq_ignore_ascii_case("cache-control")
1593 });
1594 if !has_cache_control {
1595 if is_probably_immutable_asset_path(request.path()) {
1596 response.add_header(
1597 "Cache-Control",
1598 "public, max-age=31536000, immutable",
1599 );
1600 } else {
1601 response.add_header(
1602 "Cache-Control",
1603 &format!("public, max-age={ttl}"),
1604 );
1605 }
1606 }
1607 }
1608
1609 response
1610}
1611
/// Heuristically detects fingerprinted asset paths such as
/// `/assets/app-3f2a1b9c.js` or `/assets/app.3f2a1b9c.css`.
///
/// The file name must have an extension. The part of its stem after the
/// last `-` or `.` must be at least 8 hexadecimal digits, and must be
/// preceded by a non-empty name.
///
/// Names made entirely of hex digits (for example a date such as
/// `20240115.html`) are not treated as fingerprinted. Caching them as
/// immutable would pin content that can change.
fn is_probably_immutable_asset_path(path: &str) -> bool {
    let file_name = path.rsplit('/').next().unwrap_or(path);
    let Some((stem, _extension)) = file_name.rsplit_once('.') else {
        return false;
    };
    let Some((name, fingerprint)) =
        stem.rsplit_once(|c: char| c == '-' || c == '.')
    else {
        return false;
    };
    !name.is_empty()
        && fingerprint.len() >= 8
        && fingerprint.chars().all(|c| c.is_ascii_hexdigit())
}
1622
/// Maps a file path to the MIME type used for its `Content-Type` header.
///
/// Only the final extension of `path` is considered, and the comparison is
/// ASCII case-insensitive, so `PHOTO.JPG` and `photo.jpg` both map to
/// `image/jpeg`. Paths without an extension, extensions that are not valid
/// UTF-8, and unknown extensions all fall back to
/// `application/octet-stream`.
fn get_content_type(path: &Path) -> &'static str {
    let Some(extension) = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .map(str::to_ascii_lowercase)
    else {
        return "application/octet-stream";
    };

    match extension.as_str() {
        "html" | "htm" => "text/html",
        "css" => "text/css",
        "js" | "mjs" => "application/javascript",
        "ts" => "application/typescript",
        "json" => "application/json",
        "xml" => "application/xml",
        "txt" => "text/plain",
        "md" | "markdown" => "text/markdown",
        "yaml" | "yml" => "application/x-yaml",
        "toml" => "application/toml",

        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",

        "webp" => "image/webp",
        "avif" => "image/avif",
        "heic" | "heif" => "image/heic",
        "jxl" => "image/jxl",
        "bmp" => "image/bmp",
        "tiff" | "tif" => "image/tiff",

        "wasm" => "application/wasm",

        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "otf" => "font/otf",
        "eot" => "application/vnd.ms-fontobject",

        "mp3" => "audio/mpeg",
        "wav" => "audio/wav",
        "ogg" => "audio/ogg",
        "opus" => "audio/opus",
        "flac" => "audio/flac",
        "m4a" => "audio/mp4",
        "aac" => "audio/aac",

        "mp4" => "video/mp4",
        "webm" => "video/webm",
        "av1" => "video/av1",
        "avi" => "video/x-msvideo",
        "mov" => "video/quicktime",

        "pdf" => "application/pdf",
        "zip" => "application/zip",
        "tar" => "application/x-tar",
        "gz" => "application/gzip",

        // Source maps are JSON documents.
        "map" => "application/json",
        "webmanifest" => "application/manifest+json",

        _ => "application/octet-stream",
    }
}
1701
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io;
    use std::io::Read;
    use std::io::Write;
    use std::net::{TcpListener, TcpStream};
    use tempfile::TempDir;

    /// Creates a temporary document root containing `index.html`,
    /// `404/index.html` and `subdir/index.html`.
    fn setup_test_directory() -> TempDir {
        let temp_dir = TempDir::new().unwrap();
        let root_path = temp_dir.path();

        let mut index_file =
            File::create(root_path.join("index.html")).unwrap();
        index_file
            .write_all(b"<html><body>Hello, World!</body></html>")
            .unwrap();

        fs::create_dir(root_path.join("404")).unwrap();
        let mut not_found_file =
            File::create(root_path.join("404/index.html")).unwrap();
        not_found_file
            .write_all(b"<html><body>404 Not Found</body></html>")
            .unwrap();

        fs::create_dir(root_path.join("subdir")).unwrap();
        let mut subdir_index_file =
            File::create(root_path.join("subdir/index.html")).unwrap();
        subdir_index_file
            .write_all(b"<html><body>Subdirectory Index</body></html>")
            .unwrap();

        temp_dir
    }

    /// Sends `request` to a one-shot `handle_connection` over a real loopback
    /// socket and returns everything the server wrote back.
    fn roundtrip_handle_connection(
        server: &Server,
        request: &[u8],
    ) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr = listener.local_addr().expect("addr");
        let server_clone = server.clone();
        let handle = thread::spawn(move || {
            let (stream, _) = listener.accept().expect("accept");
            handle_connection(stream, &server_clone).expect("handle");
        });

        let mut client = TcpStream::connect(addr).expect("connect");
        client.write_all(request).expect("write");
        let mut response = String::new();
        let _ = client.read_to_string(&mut response).expect("read");
        handle.join().expect("join");
        response
    }

    /// Returns a connected `(server_side, client_side)` loopback stream pair.
    fn connected_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr = listener.local_addr().expect("addr");
        let client = TcpStream::connect(addr).expect("connect");
        let (server, _) = listener.accept().expect("accept");
        (server, client)
    }

    #[test]
    fn test_server_creation() {
        let server = Server::new("127.0.0.1:8080", "/var/www");
        assert_eq!(server.address, "127.0.0.1:8080");
        assert_eq!(server.document_root, PathBuf::from("/var/www"));
    }

    #[test]
    fn test_get_content_type() {
        assert_eq!(
            get_content_type(Path::new("test.html")),
            "text/html"
        );
        assert_eq!(
            get_content_type(Path::new("page.htm")),
            "text/html"
        );
        assert_eq!(
            get_content_type(Path::new("style.css")),
            "text/css"
        );
        assert_eq!(
            get_content_type(Path::new("script.js")),
            "application/javascript"
        );
        assert_eq!(
            get_content_type(Path::new("data.json")),
            "application/json"
        );
        assert_eq!(
            get_content_type(Path::new("image.png")),
            "image/png"
        );
        assert_eq!(
            get_content_type(Path::new("photo.jpg")),
            "image/jpeg"
        );
        assert_eq!(
            get_content_type(Path::new("animation.gif")),
            "image/gif"
        );
        assert_eq!(
            get_content_type(Path::new("icon.svg")),
            "image/svg+xml"
        );
        assert_eq!(
            get_content_type(Path::new("unknown.xyz")),
            "application/octet-stream"
        );
    }

    #[test]
    fn test_modern_content_types() {
        assert_eq!(
            get_content_type(Path::new("image.webp")),
            "image/webp"
        );
        assert_eq!(
            get_content_type(Path::new("image.avif")),
            "image/avif"
        );
        assert_eq!(
            get_content_type(Path::new("image.heic")),
            "image/heic"
        );
        assert_eq!(
            get_content_type(Path::new("image.heif")),
            "image/heic"
        );
        assert_eq!(
            get_content_type(Path::new("image.jxl")),
            "image/jxl"
        );

        assert_eq!(
            get_content_type(Path::new("module.wasm")),
            "application/wasm"
        );

        assert_eq!(
            get_content_type(Path::new("script.ts")),
            "application/typescript"
        );
        assert_eq!(
            get_content_type(Path::new("module.mjs")),
            "application/javascript"
        );
        assert_eq!(
            get_content_type(Path::new("README.md")),
            "text/markdown"
        );
        assert_eq!(
            get_content_type(Path::new("config.yaml")),
            "application/x-yaml"
        );
        assert_eq!(
            get_content_type(Path::new("config.yml")),
            "application/x-yaml"
        );
        assert_eq!(
            get_content_type(Path::new("Cargo.toml")),
            "application/toml"
        );

        assert_eq!(
            get_content_type(Path::new("audio.opus")),
            "audio/opus"
        );
        assert_eq!(
            get_content_type(Path::new("audio.flac")),
            "audio/flac"
        );

        assert_eq!(
            get_content_type(Path::new("video.av1")),
            "video/av1"
        );

        assert_eq!(
            get_content_type(Path::new("script.js.map")),
            "application/json"
        );
        assert_eq!(
            get_content_type(Path::new("manifest.webmanifest")),
            "application/manifest+json"
        );
    }

    #[test]
    fn test_generate_response() {
        let temp_dir = setup_test_directory();
        let document_root = temp_dir.path();

        let root_request = Request {
            method: "GET".to_string(),
            path: "/".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: HashMap::new(),
        };

        let root_response =
            generate_response(&root_request, document_root).unwrap();
        assert_eq!(root_response.status_code, 200);
        assert_eq!(root_response.status_text, "OK");
        assert!(
            root_response.body.starts_with(
                b"<html><body>Hello, World!</body></html>"
            )
        );

        let file_request = Request {
            method: "GET".to_string(),
            path: "/index.html".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: HashMap::new(),
        };

        let file_response =
            generate_response(&file_request, document_root).unwrap();
        assert_eq!(file_response.status_code, 200);
        assert_eq!(file_response.status_text, "OK");
        assert!(
            file_response.body.starts_with(
                b"<html><body>Hello, World!</body></html>"
            )
        );

        let subdir_request = Request {
            method: "GET".to_string(),
            path: "/subdir/".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: HashMap::new(),
        };

        let subdir_response =
            generate_response(&subdir_request, document_root).unwrap();
        assert_eq!(subdir_response.status_code, 200);
        assert_eq!(subdir_response.status_text, "OK");
        assert!(subdir_response.body.starts_with(
            b"<html><body>Subdirectory Index</body></html>"
        ));

        let not_found_request = Request {
            method: "GET".to_string(),
            path: "/nonexistent.html".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: HashMap::new(),
        };

        let not_found_response =
            generate_response(&not_found_request, document_root)
                .unwrap();
        assert_eq!(not_found_response.status_code, 404);
        assert_eq!(not_found_response.status_text, "NOT FOUND");
        assert!(
            not_found_response.body.starts_with(
                b"<html><body>404 Not Found</body></html>"
            )
        );

        let traversal_request = Request {
            method: "GET".to_string(),
            path: "/../outside.html".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: HashMap::new(),
        };

        let traversal_response =
            generate_response(&traversal_request, document_root);
        assert!(matches!(
            traversal_response,
            Err(ServerError::Forbidden(_))
        ));
    }

    #[test]
    fn test_server_builder() {
        let server = Server::builder()
            .address("127.0.0.1:8080")
            .document_root("/var/www")
            .enable_cors()
            .custom_header("X-Custom", "test-value")
            .request_timeout(Duration::from_secs(30))
            .build()
            .unwrap();

        assert_eq!(server.address, "127.0.0.1:8080");
        assert_eq!(server.document_root, PathBuf::from("/var/www"));
        assert_eq!(server.cors_enabled, Some(true));
        assert_eq!(
            server.request_timeout,
            Some(Duration::from_secs(30))
        );

        let headers = server.custom_headers.unwrap();
        assert_eq!(
            headers.get("X-Custom"),
            Some(&"test-value".to_string())
        );

        let result = ServerBuilder::new().build();
        assert!(result.is_err());

        let server_with_origins = Server::builder()
            .address("127.0.0.1:9000")
            .document_root("/tmp")
            .cors_origins(vec!["https://example.com".to_string()])
            .build()
            .unwrap();

        assert_eq!(server_with_origins.cors_enabled, Some(true));
        assert_eq!(
            server_with_origins.cors_origins,
            Some(vec!["https://example.com".to_string()])
        );
    }

    #[test]
    fn test_graceful_shutdown() {
        let shutdown = ShutdownSignal::new(Duration::from_secs(5));

        assert!(!shutdown.is_shutdown_requested());
        assert_eq!(shutdown.active_connection_count(), 0);

        shutdown.connection_started();
        assert_eq!(shutdown.active_connection_count(), 1);

        shutdown.connection_started();
        assert_eq!(shutdown.active_connection_count(), 2);

        shutdown.connection_finished();
        assert_eq!(shutdown.active_connection_count(), 1);

        shutdown.connection_finished();
        assert_eq!(shutdown.active_connection_count(), 0);

        shutdown.shutdown();
        assert!(shutdown.is_shutdown_requested());

        let graceful = shutdown.wait_for_shutdown();
        assert!(graceful);
    }

    #[test]
    fn test_shutdown_signal_timeout() {
        let shutdown = ShutdownSignal::new(Duration::from_millis(100));

        // One connection never finishes, so waiting must time out.
        shutdown.connection_started();
        shutdown.shutdown();

        let graceful = shutdown.wait_for_shutdown();
        assert!(!graceful);
    }

    #[test]
    fn test_thread_pool() {
        use std::sync::Arc;
        use std::sync::atomic::AtomicUsize;
        use std::sync::mpsc;

        let pool = ThreadPool::new(4);
        let counter = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = mpsc::channel();

        for _ in 0..10 {
            let counter_clone = Arc::clone(&counter);
            let tx_clone = tx.clone();

            pool.execute(move || {
                let _ = counter_clone.fetch_add(1, Ordering::SeqCst);
                tx_clone.send(()).unwrap();
            });
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn test_connection_pool() {
        let pool = ConnectionPool::new(2);
        assert_eq!(pool.active_count(), 0);

        let guard1 = pool.acquire().unwrap();
        assert_eq!(pool.active_count(), 1);

        let guard2 = pool.acquire().unwrap();
        assert_eq!(pool.active_count(), 2);

        let result = pool.acquire();
        assert!(result.is_err());
        assert_eq!(pool.active_count(), 2);

        drop(guard1);
        assert_eq!(pool.active_count(), 1);

        let _guard3 = pool.acquire().unwrap();
        assert_eq!(pool.active_count(), 2);

        drop(guard2);
        drop(_guard3);
        assert_eq!(pool.active_count(), 0);
    }

    #[test]
    fn test_thread_pool_concurrent_execution() {
        use std::sync::Arc;
        use std::sync::atomic::AtomicUsize;
        use std::sync::mpsc;
        use std::time::Instant;

        let pool = ThreadPool::new(4);
        let counter = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = mpsc::channel();

        let start_time = Instant::now();

        for i in 0..100 {
            let counter_clone = Arc::clone(&counter);
            let tx_clone = tx.clone();

            pool.execute(move || {
                thread::sleep(Duration::from_millis(10));
                let _ = counter_clone.fetch_add(1, Ordering::SeqCst);
                tx_clone.send(i).unwrap();
            });
        }

        for _ in 0..100 {
            let _ = rx.recv().unwrap();
        }

        let elapsed = start_time.elapsed();
        assert_eq!(counter.load(Ordering::SeqCst), 100);

        // 100 x 10 ms sequentially would take ~1 s; 4 workers should need
        // roughly a quarter of that.
        assert!(
            elapsed.as_millis() < 800,
            "Thread pool should provide concurrency benefits"
        );
    }

    #[test]
    fn test_connection_pool_backpressure() {
        let pool = ConnectionPool::new(2);

        let _guard1 = pool.acquire().unwrap();
        let _guard2 = pool.acquire().unwrap();
        assert_eq!(pool.active_count(), 2);

        let result = pool.acquire();
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().kind(),
            io::ErrorKind::WouldBlock
        );
    }

    #[test]
    fn test_service_unavailable_response() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        // Keep the handle so a panic on the server side fails this test
        // instead of being silently discarded with the detached thread.
        let handle = thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            send_service_unavailable(stream).unwrap();
        });

        let mut client_stream = TcpStream::connect(addr).unwrap();
        let mut response = String::new();
        let _ = client_stream.read_to_string(&mut response).unwrap();
        handle.join().expect("join");

        assert!(response.contains("503"));
        assert!(response.contains("SERVICE UNAVAILABLE"));
        assert!(response.contains("Service temporarily unavailable"));
        assert!(response.contains("Retry-After: 1"));
    }

    #[test]
    fn test_service_unavailable_send_failure_is_mapped() {
        use std::net::Shutdown;

        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr = listener.local_addr().expect("addr");

        let t = thread::spawn(move || {
            let (stream, _) = listener.accept().expect("accept");
            stream.shutdown(Shutdown::Write).expect("shutdown");
            let err =
                send_service_unavailable(stream).expect_err("err");
            assert!(
                err.to_string().contains("Failed to send response")
            );
        });

        let _client = TcpStream::connect(addr).expect("connect");
        t.join().expect("join");
    }

    #[test]
    fn test_response_from_error_variants() {
        let temp_dir = setup_test_directory();
        let root = temp_dir.path();

        let bad = response_from_error(
            &ServerError::InvalidRequest("bad".to_string()),
            root,
        );
        assert_eq!(bad.status_code, 400);

        let forbidden = response_from_error(
            &ServerError::Forbidden("no".to_string()),
            root,
        );
        assert_eq!(forbidden.status_code, 403);

        let not_found = response_from_error(
            &ServerError::NotFound("missing".to_string()),
            root,
        );
        assert_eq!(not_found.status_code, 404);

        let internal = response_from_error(
            &ServerError::TaskFailed("boom".to_string()),
            root,
        );
        assert_eq!(internal.status_code, 500);
    }

    #[test]
    fn test_apply_response_policies_with_cors_and_headers() {
        let mut headers = HashMap::new();
        let _ = headers
            .insert("X-App".to_string(), "http-handle".to_string());
        let server = Server::builder()
            .address("127.0.0.1:0")
            .document_root(".")
            .enable_cors()
            .cors_origins(vec!["https://example.com".to_string()])
            .custom_headers(headers)
            .build()
            .expect("server builder");

        let request = Request {
            method: "OPTIONS".to_string(),
            path: "/".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: HashMap::new(),
        };
        let response = apply_response_policies(
            Response::new(200, "OK", Vec::new()),
            &server,
            &request,
        );

        let has_origin = response.headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("Access-Control-Allow-Origin")
                && v == "https://example.com"
        });
        let has_custom = response
            .headers
            .iter()
            .any(|(k, v)| k == "X-App" && v == "http-handle");
        let has_max_age = response.headers.iter().any(|(k, _)| {
            k.eq_ignore_ascii_case("Access-Control-Max-Age")
        });

        assert!(has_origin);
        assert!(has_custom);
        assert!(has_max_age);
    }

    #[test]
    fn test_thread_pool_debug_representation() {
        let pool = ThreadPool::new(1);
        let rendered = format!("{pool:?}");
        assert!(rendered.contains("ThreadPool"));
        assert!(rendered.contains("<Sender<Job>>"));
    }

    #[test]
    fn test_server_getters_expose_builder_config() {
        let mut headers = HashMap::new();
        let _ =
            headers.insert("X-Test".to_string(), "value".to_string());

        let server = Server::builder()
            .address("127.0.0.1:9001")
            .document_root("/tmp")
            .enable_cors()
            .cors_origins(vec!["https://example.com".to_string()])
            .custom_headers(headers)
            .request_timeout(Duration::from_secs(5))
            .connection_timeout(Duration::from_secs(7))
            .build()
            .expect("server");

        assert_eq!(server.cors_enabled(), Some(true));
        assert_eq!(
            server.cors_origins(),
            &Some(vec!["https://example.com".to_string()])
        );
        assert_eq!(
            server.request_timeout(),
            Some(Duration::from_secs(5))
        );
        assert_eq!(
            server.connection_timeout(),
            Some(Duration::from_secs(7))
        );
        assert_eq!(server.address(), "127.0.0.1:9001");
        assert_eq!(server.document_root(), &PathBuf::from("/tmp"));
        assert_eq!(
            server
                .custom_headers()
                .as_ref()
                .and_then(|h| h.get("X-Test")),
            Some(&"value".to_string())
        );
    }

    #[test]
    fn test_start_variants_return_bind_errors_for_in_use_address() {
        let occupied = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr = occupied.local_addr().expect("addr").to_string();
        let server = Server::new(&addr, ".");

        assert!(server.start().is_err());
        assert!(
            server
                .start_with_graceful_shutdown(Duration::from_millis(1))
                .is_err()
        );
        assert!(server.start_with_thread_pool(1).is_err());
        assert!(server.start_with_pooling(1, 1).is_err());
    }

    #[test]
    fn test_start_with_shutdown_signal_and_ready_reports_bound_address()
    {
        let root = setup_test_directory();
        let server = Server::builder()
            .address("127.0.0.1:0")
            .document_root(root.path().to_str().expect("path"))
            .build()
            .expect("server");

        let (ready_tx, ready_rx) = mpsc::channel::<String>();
        let shutdown =
            Arc::new(ShutdownSignal::new(Duration::from_secs(2)));
        let shutdown_for_server = shutdown.clone();
        let server_for_thread = server.clone();

        let handle = thread::spawn(move || {
            server_for_thread
                .start_with_shutdown_signal_and_ready(
                    shutdown_for_server,
                    move |addr| {
                        let _ = ready_tx.send(addr);
                    },
                )
                .expect("server run");
        });

        let bound_addr = ready_rx
            .recv_timeout(Duration::from_secs(2))
            .expect("bound address");
        assert!(bound_addr.starts_with("127.0.0.1:"));
        assert_ne!(bound_addr, "127.0.0.1:0");

        let mut stream =
            TcpStream::connect(&bound_addr).expect("connect");
        stream
            .write_all(
                b"GET /index.html HTTP/1.1\r\nHost: localhost\r\n\r\n",
            )
            .expect("write");
        let mut response = String::new();
        let _ = stream.read_to_string(&mut response);
        assert!(response.starts_with("HTTP/1.1 200 OK"));

        shutdown.shutdown();
        handle.join().expect("join");
    }

    #[test]
    fn test_generate_response_falls_back_to_builtin_404_without_page() {
        let temp_dir = TempDir::new().expect("tmp");
        fs::write(temp_dir.path().join("index.html"), b"index")
            .expect("write");
        fs::create_dir(temp_dir.path().join("empty-dir")).expect("dir");

        let request = Request {
            method: "GET".to_string(),
            path: "/empty-dir/".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: HashMap::new(),
        };

        let response = generate_response(&request, temp_dir.path())
            .expect("response");
        assert_eq!(response.status_code, 404);
        assert_eq!(response.body, b"404 Not Found".to_vec());
    }

    #[cfg(unix)]
    #[test]
    fn test_response_from_error_not_found_fallback_when_custom_404_unreadable()
    {
        use std::os::unix::fs::PermissionsExt;

        let temp_dir = TempDir::new().expect("tmp");
        let custom_404_dir = temp_dir.path().join("404");
        fs::create_dir(&custom_404_dir).expect("create 404 dir");
        let custom_404 = custom_404_dir.join("index.html");
        fs::write(&custom_404, b"custom").expect("write 404");
        fs::set_permissions(
            &custom_404,
            fs::Permissions::from_mode(0o000),
        )
        .expect("chmod");

        // Privileged users (e.g. root inside CI containers) ignore
        // permission bits, so the file stays readable and the fallback path
        // this test targets cannot be exercised; skip rather than fail.
        if fs::read(&custom_404).is_ok() {
            return;
        }

        let response = response_from_error(
            &ServerError::NotFound("missing".to_string()),
            temp_dir.path(),
        );

        assert_eq!(response.status_code, 404);
        assert_eq!(response.status_text, "NOT FOUND");
        assert_eq!(response.body, b"404 Not Found".to_vec());
    }

    #[test]
    fn test_handle_connection_options_and_parse_error_paths() {
        let root = setup_test_directory();
        let root_str = root.path().to_str().expect("root path");
        let server = Server::builder()
            .address("127.0.0.1:0")
            .document_root(root_str)
            .build()
            .expect("server");

        let options_response = roundtrip_handle_connection(
            &server,
            b"OPTIONS / HTTP/1.1\r\nHost: localhost\r\n\r\n",
        );
        assert!(options_response.starts_with("HTTP/1.1 200 OK"));
        assert!(options_response.contains("Allow: GET, HEAD, OPTIONS"));

        let head_response = roundtrip_handle_connection(
            &server,
            b"HEAD / HTTP/1.1\r\nHost: localhost\r\n\r\n",
        );
        assert!(head_response.starts_with("HTTP/1.1 200 OK"));
        assert!(head_response.contains("Content-Length:"));

        let method_not_allowed = roundtrip_handle_connection(
            &server,
            b"POST / HTTP/1.1\r\nHost: localhost\r\n\r\n",
        );
        assert!(
            method_not_allowed
                .starts_with("HTTP/1.1 405 METHOD NOT ALLOWED")
        );

        let traversal_response = roundtrip_handle_connection(
            &server,
            b"GET /../outside HTTP/1.1\r\nHost: localhost\r\n\r\n",
        );
        assert!(
            traversal_response.starts_with("HTTP/1.1 403 FORBIDDEN")
        );

        let bad_response =
            roundtrip_handle_connection(&server, b"BAD\r\n\r\n");
        assert!(bad_response.starts_with("HTTP/1.1 400 BAD REQUEST"));
    }

    #[test]
    fn test_generate_response_supports_etag_and_range() {
        let temp_dir = setup_test_directory();
        let root = temp_dir.path();

        let mut headers = HashMap::new();
        let _ = headers
            .insert("range".to_string(), "bytes=0-4".to_string());
        let range_request = Request {
            method: "GET".to_string(),
            path: "/index.html".to_string(),
            version: "HTTP/1.1".to_string(),
            headers,
        };
        let range_response =
            generate_response(&range_request, root).expect("range");
        assert_eq!(range_response.status_code, 206);
        assert!(range_response.body.starts_with(b"<html"));
        let etag = range_response
            .headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("etag"))
            .map(|(_, value)| value.clone())
            .expect("etag");

        let mut headers = HashMap::new();
        let _ = headers.insert("if-none-match".to_string(), etag);
        let conditional_request = Request {
            method: "GET".to_string(),
            path: "/index.html".to_string(),
            version: "HTTP/1.1".to_string(),
            headers,
        };
        let conditional_response =
            generate_response(&conditional_request, root)
                .expect("conditional");
        assert_eq!(conditional_response.status_code, 304);
        assert!(conditional_response.body.is_empty());
    }

    #[test]
    fn test_metrics_and_rate_limiting() {
        let root = setup_test_directory();
        let server = Server::builder()
            .address("127.0.0.1:0")
            .document_root(root.path().to_str().expect("path"))
            .rate_limit_per_minute(1)
            .build()
            .expect("server");

        let _ = roundtrip_handle_connection(
            &server,
            b"GET /index.html HTTP/1.1\r\nHost: localhost\r\n\r\n",
        );
        let limited = roundtrip_handle_connection(
            &server,
            b"GET /index.html HTTP/1.1\r\nHost: localhost\r\n\r\n",
        );
        assert!(limited.starts_with("HTTP/1.1 429 TOO MANY REQUESTS"));

        let metrics = roundtrip_handle_connection(
            &server,
            b"GET /metrics HTTP/1.1\r\nHost: localhost\r\n\r\n",
        );
        assert!(metrics.starts_with("HTTP/1.1 200 OK"));
        assert!(metrics.contains("http_handle_requests_total"));
    }

    #[test]
    fn test_trigger_shutdown_from_slot_helper() {
        let shutdown =
            Arc::new(ShutdownSignal::new(Duration::from_secs(1)));
        let slot = Mutex::new(Some(shutdown.clone()));
        assert!(!shutdown.is_shutdown_requested());
        Server::trigger_shutdown_from_slot(&slot);
        assert!(shutdown.is_shutdown_requested());
    }

    #[test]
    fn test_handle_shutdown_signal_helper() {
        let shutdown =
            Arc::new(ShutdownSignal::new(Duration::from_secs(1)));
        let slot =
            SHUTDOWN_SIGNAL_SLOT.get_or_init(|| Mutex::new(None));
        // Install our signal in the process-wide slot, remembering whatever
        // was registered before so other tests see the slot unchanged.
        let previous = slot
            .lock()
            .ok()
            .and_then(|mut guard| guard.replace(shutdown.clone()));
        Server::handle_shutdown_signal();
        assert!(shutdown.is_shutdown_requested());
        if let Ok(mut guard) = slot.lock() {
            *guard = previous;
        }
    }

    #[test]
    fn test_accept_loop_helpers_cover_ok_and_err_paths() {
        let root = setup_test_directory();
        let server = Server::builder()
            .address("127.0.0.1:0")
            .document_root(root.path().to_str().expect("path"))
            .build()
            .expect("server");

        Server::run_basic_accept_loop(
            vec![Err(io::Error::other("incoming failed"))],
            server.clone(),
        );
        let pool = ThreadPool::new(1);
        Server::run_thread_pool_accept_loop(
            vec![Err(io::Error::other("incoming failed"))],
            server.clone(),
            &pool,
        );
        Server::run_pooling_accept_loop(
            vec![Err(io::Error::other("incoming failed"))],
            server.clone(),
            &pool,
            Arc::new(ConnectionPool::new(1)),
        );

        let (server_stream, mut client) = connected_pair();
        client.write_all(b"BAD\r\n\r\n").expect("write");
        Server::run_basic_accept_loop(
            vec![Ok(server_stream)],
            server.clone(),
        );
        let mut response = String::new();
        let _ = client.read_to_string(&mut response).expect("read");
        assert!(response.starts_with("HTTP/1.1 400 BAD REQUEST"));

        let (server_stream, mut client) = connected_pair();
        client.write_all(b"BAD\r\n\r\n").expect("write");
        Server::run_thread_pool_accept_loop(
            vec![Ok(server_stream)],
            server.clone(),
            &pool,
        );
        let mut response = String::new();
        let _ = client.read_to_string(&mut response).expect("read");
        assert!(response.starts_with("HTTP/1.1 400 BAD REQUEST"));

        let (server_stream, mut client) = connected_pair();
        client.write_all(b"BAD\r\n\r\n").expect("write");
        Server::run_pooling_accept_loop(
            vec![Ok(server_stream)],
            server.clone(),
            &pool,
            Arc::new(ConnectionPool::new(1)),
        );
        let mut response = String::new();
        let _ = client.read_to_string(&mut response).expect("read");
        assert!(response.starts_with("HTTP/1.1 400 BAD REQUEST"));

        // A zero-capacity connection pool must reject with 503.
        let (server_stream, mut client) = connected_pair();
        Server::run_pooling_accept_loop(
            vec![Ok(server_stream)],
            server,
            &pool,
            Arc::new(ConnectionPool::new(0)),
        );
        let mut response = String::new();
        let _ = client.read_to_string(&mut response).expect("read");
        assert!(
            response.starts_with("HTTP/1.1 503 SERVICE UNAVAILABLE")
        );
    }

    #[test]
    fn test_immutable_cache_control_policy() {
        let root = setup_test_directory();
        let server = Server::builder()
            .address("127.0.0.1:0")
            .document_root(root.path().to_str().expect("path"))
            .static_cache_ttl_secs(60)
            .build()
            .expect("server");

        let request = Request {
            method: "GET".to_string(),
            path: "/assets/app-abcdef12.js".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: HashMap::new(),
        };
        let response = apply_response_policies(
            Response::new(200, "OK", b"ok".to_vec()),
            &server,
            &request,
        );
        assert!(response.headers.iter().any(|(name, value)| {
            name.eq_ignore_ascii_case("cache-control")
                && value.contains("immutable")
        }));
    }

    #[test]
    fn test_zstd_precompressed_asset_is_served() {
        let root = setup_test_directory();
        let file = root.path().join("index.html.zst");
        fs::write(&file, b"zstd-data").expect("write");

        let mut headers = HashMap::new();
        let _ = headers.insert(
            "accept-encoding".to_string(),
            "zstd,gzip".to_string(),
        );
        let request = Request {
            method: "GET".to_string(),
            path: "/index.html".to_string(),
            version: "HTTP/1.1".to_string(),
            headers,
        };

        let response =
            generate_response_with_cache(&request, root.path(), None)
                .expect("response");
        assert!(response.headers.iter().any(|(name, value)| {
            name.eq_ignore_ascii_case("content-encoding")
                && value.eq_ignore_ascii_case("zstd")
        }));
        assert_eq!(response.body, b"zstd-data");
    }
}