1#[cfg(feature = "http2")]
10#[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
11use crate::error::ServerError;
12#[cfg(feature = "http2")]
13#[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
14use crate::request::Request;
15#[cfg(feature = "http2")]
16#[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
17use crate::server::{Server, build_response_for_request_with_metrics};
18
/// Accepts TCP connections on `server.address()` and serves each one over
/// cleartext HTTP/2 (prior knowledge: the `h2` handshake runs directly on the
/// raw TCP stream) on its own detached Tokio task.
///
/// A failing connection is reported on stderr and never stops the accept
/// loop. Accept errors get the same treatment:
/// - Errors tied to one pending connection (aborted, reset, refused,
///   interrupted) are skipped.
/// - Any other accept error (for example running out of file descriptors) is
///   logged and retried after a short pause, instead of shutting the whole
///   server down.
///
/// # Errors
///
/// Returns an error only if the listener cannot be bound. After that the
/// future runs until it is dropped or its task is aborted.
#[cfg(feature = "http2")]
#[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
pub async fn start_http2(server: Server) -> Result<(), ServerError> {
    // Pause before retrying an accept error that is not scoped to a single
    // connection (e.g. EMFILE), so the loop does not spin on a hot error.
    const ACCEPT_RETRY_DELAY: std::time::Duration =
        std::time::Duration::from_millis(100);

    /// Returns true for accept errors that only concern the connection being
    /// accepted; the listener itself is still healthy.
    fn is_per_connection_accept_error(error: &std::io::Error) -> bool {
        matches!(
            error.kind(),
            std::io::ErrorKind::ConnectionAborted
                | std::io::ErrorKind::ConnectionReset
                | std::io::ErrorKind::ConnectionRefused
                | std::io::ErrorKind::Interrupted
        )
    }

    let listener = tokio::net::TcpListener::bind(server.address())
        .await
        .map_err(ServerError::from)?;

    loop {
        let stream = match listener.accept().await {
            Ok((stream, _peer)) => stream,
            Err(error) if is_per_connection_accept_error(&error) => continue,
            Err(error) => {
                eprintln!("HTTP/2 accept error: {error}");
                tokio::time::sleep(ACCEPT_RETRY_DELAY).await;
                continue;
            }
        };

        let server_clone = server.clone();
        drop(tokio::spawn(async move {
            if let Err(error) =
                handle_h2_connection(stream, server_clone).await
            {
                eprintln!("HTTP/2 connection error: {}", error);
            }
        }));
    }
}
64
/// Serves every request stream on one accepted TCP connection using HTTP/2
/// with prior knowledge (the `h2` handshake runs directly on `stream`).
///
/// Streams are handled one after another, in the order `h2` yields them.
/// Each request is converted with `map_h2_request`, answered by
/// `build_response_for_request_with_metrics`, and written back with
/// `send_h2_response`.
///
/// A failure while writing one stream's response is logged and the loop
/// moves on. Such a failure can happen, for example, when the client already
/// reset that stream or the response head cannot be encoded. Handling it this
/// way means one bad stream does not tear down the other streams multiplexed
/// on the same connection.
///
/// # Errors
///
/// Returns `ServerError::Custom` in two cases:
/// - the HTTP/2 handshake fails (e.g. an invalid client preface);
/// - `h2` reports an error while waiting for the next stream.
///
/// Returns `Ok(())` once `h2` reports that no more streams will arrive.
#[cfg(feature = "http2")]
#[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
async fn handle_h2_connection(
    stream: tokio::net::TcpStream,
    server: Server,
) -> Result<(), ServerError> {
    let mut connection =
        h2::server::handshake(stream).await.map_err(|e| {
            ServerError::Custom(format!("h2 handshake: {e}"))
        })?;

    while let Some(next) = connection.accept().await {
        // Errors surfaced by `accept` concern the connection as a whole, so
        // the connection is abandoned.
        let (request, respond) = next.map_err(|e| {
            ServerError::Custom(format!("h2 accept: {e}"))
        })?;
        let parsed_request = map_h2_request(&request);
        let response = build_response_for_request_with_metrics(
            &server,
            &parsed_request,
        );

        // A send failure is scoped to this one stream; keep serving the
        // remaining streams instead of dropping the whole connection.
        if let Err(error) = send_h2_response(respond, response) {
            eprintln!("HTTP/2 stream error: {error}");
        }
    }

    Ok(())
}
90
/// Converts an incoming `http::Request` (as yielded by `h2`) into the crate's
/// transport-agnostic [`Request`].
///
/// Only these parts are carried over:
/// - the method;
/// - the URI path (the query string is not included, since only
///   `uri().path()` is read);
/// - the protocol version;
/// - the headers.
///
/// The body type `B` is never inspected. Header values that
/// `HeaderValue::to_str` rejects are skipped, and header names are stored
/// lowercased.
///
/// `http::Version::HTTP_2` maps to `"HTTP/2.0"`; every other version is
/// reported as `"HTTP/1.1"`.
#[cfg(feature = "http2")]
#[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
fn map_h2_request<B>(request: &http::Request<B>) -> Request {
    let version = if request.version() == http::Version::HTTP_2 {
        "HTTP/2.0"
    } else {
        "HTTP/1.1"
    };

    let headers = request
        .headers()
        .iter()
        .filter_map(|(header_name, header_value)| {
            let text = header_value.to_str().ok()?;
            Some((header_name.as_str().to_ascii_lowercase(), text.to_owned()))
        })
        .collect();

    let method = request.method().as_str().to_owned();
    let path = request.uri().path().to_owned();

    Request {
        method,
        path,
        version: version.to_owned(),
        headers,
    }
}
116
/// Writes `response` to an accepted HTTP/2 stream.
///
/// The head (status and headers, built by `build_h2_head`) is sent first. An
/// empty body is signalled by ending the stream on the head itself. Otherwise
/// the whole body is handed to `h2` in a single `send_data` call with
/// end-of-stream set.
///
/// # Errors
///
/// Returns `ServerError::Custom` in either of these cases:
/// - the head cannot be built;
/// - `h2` refuses the response head or the body data.
#[cfg(feature = "http2")]
#[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
fn send_h2_response(
    mut respond: h2::server::SendResponse<bytes::Bytes>,
    response: crate::response::Response,
) -> Result<(), ServerError> {
    let head = build_h2_head(&response)?;
    let headers_only = response.body.is_empty();

    let mut body_stream =
        respond.send_response(head, headers_only).map_err(|error| {
            ServerError::Custom(format!(
                "failed to send h2 response headers: {error}"
            ))
        })?;

    // The stream was already closed by the head; nothing left to send.
    if headers_only {
        return Ok(());
    }

    let payload = bytes::Bytes::from(response.body);
    body_stream.send_data(payload, true).map_err(|error| {
        ServerError::Custom(format!(
            "failed to send h2 response body: {error}"
        ))
    })
}
146
/// Builds the header-only `http::Response` that `h2` sends as the response
/// head for `response`.
///
/// Connection-specific header fields are left out. RFC 9113 §8.2.2 forbids
/// `Connection`, `Keep-Alive`, `Proxy-Connection`, `Transfer-Encoding` and
/// `Upgrade` in HTTP/2, and allows `TE` only with the value `trailers`. `h2`
/// refuses to send a head carrying them, so forwarding such a header (e.g.
/// one meant for HTTP/1.1 clients) would make the entire response fail.
///
/// # Errors
///
/// Returns `ServerError::Custom` when the status code, or any header
/// name/value that is forwarded, is not valid for `http`. One example is a
/// header name containing a space.
#[cfg(feature = "http2")]
#[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
fn build_h2_head(
    response: &crate::response::Response,
) -> Result<http::Response<()>, ServerError> {
    /// Returns true for header fields that must not appear in HTTP/2.
    fn is_connection_specific(name: &str, value: &str) -> bool {
        const HOP_BY_HOP: [&str; 5] = [
            "connection",
            "keep-alive",
            "proxy-connection",
            "transfer-encoding",
            "upgrade",
        ];
        HOP_BY_HOP.iter().any(|hop| name.eq_ignore_ascii_case(hop))
            || (name.eq_ignore_ascii_case("te") && value != "trailers")
    }

    let mut builder =
        http::Response::builder().status(response.status_code);
    for (name, value) in &response.headers {
        if is_connection_specific(name, value) {
            continue;
        }
        builder = builder.header(name, value);
    }
    builder.body(()).map_err(|error| {
        ServerError::Custom(format!(
            "failed to build h2 response headers: {error}"
        ))
    })
}
163
#[cfg(all(test, feature = "http2"))]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::Version;
    use std::io::Write;
    use std::net::TcpListener;
    use tempfile::TempDir;
    use tokio::io::AsyncWriteExt;
    use tokio::task::JoinHandle;
    use tokio::time::{Duration, sleep, timeout};

    /// Picks a currently free localhost port and returns it as `ip:port`.
    ///
    /// The probe listener is closed before returning, so the port is only very
    /// likely (not guaranteed) to still be free when the test binds it.
    fn free_addr() -> String {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr = listener.local_addr().expect("addr");
        drop(listener);
        addr.to_string()
    }

    /// Creates a temporary document root containing `index.html` and
    /// `404/index.html` with the given contents.
    fn site_root(index_body: &[u8], not_found_body: &[u8]) -> TempDir {
        let root = TempDir::new().expect("tmp");
        std::fs::write(root.path().join("index.html"), index_body)
            .expect("write index");
        std::fs::create_dir(root.path().join("404")).expect("404 dir");
        std::fs::write(root.path().join("404/index.html"), not_found_body)
            .expect("write 404");
        root
    }

    /// Builds a `Server` configured for `addr` that serves files from `root`.
    fn build_server(addr: &str, root: &TempDir) -> Server {
        Server::builder()
            .address(addr)
            .document_root(root.path().to_str().expect("path"))
            .build()
            .expect("server")
    }

    /// Spawns `start_http2` on a free port and gives it a moment to bind.
    async fn spawn_http2_server(
        root: &TempDir,
    ) -> (String, JoinHandle<Result<(), ServerError>>) {
        let addr = free_addr();
        let task = tokio::spawn(start_http2(build_server(&addr, root)));
        sleep(Duration::from_millis(40)).await;
        (addr, task)
    }

    /// Binds a listener on a free port and spawns a task that runs
    /// `handle_h2_connection` on the first connection it accepts.
    async fn spawn_connection_handler(
        root: &TempDir,
    ) -> (String, JoinHandle<Result<(), ServerError>>) {
        let addr = free_addr();
        let listener =
            tokio::net::TcpListener::bind(&addr).await.expect("bind");
        let server = build_server(&addr, root);
        let task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept");
            handle_h2_connection(stream, server).await
        });
        (addr, task)
    }

    /// Sends one body-less request over a fresh HTTP/2 connection and
    /// returns the response status code and the complete response body.
    async fn h2_request(
        addr: &str,
        method: &str,
        uri: &str,
    ) -> (u16, Vec<u8>) {
        let stream = tokio::net::TcpStream::connect(addr)
            .await
            .expect("connect");
        let (mut client, connection) =
            h2::client::handshake(stream).await.expect("handshake");
        drop(tokio::spawn(connection));

        let request = http::Request::builder()
            .method(method)
            .uri(uri)
            .body(())
            .expect("request");
        let (response_future, _send_stream) =
            client.send_request(request, true).expect("send request");
        let response = response_future.await.expect("response");
        let status = response.status().as_u16();

        let mut body = response.into_body();
        let mut collected = Vec::new();
        while let Some(next) = body.data().await {
            let chunk: Bytes = next.expect("chunk");
            collected.extend_from_slice(&chunk);
        }
        (status, collected)
    }

    #[tokio::test]
    async fn http2_server_serves_static_file() {
        let root = site_root(b"hello-h2", b"404");
        let (addr, task) = spawn_http2_server(&root).await;

        let (status, body) =
            h2_request(&addr, "GET", "http://localhost/").await;
        assert_eq!(status, 200);
        assert_eq!(body, b"hello-h2");
        task.abort();
    }

    #[test]
    fn map_h2_request_preserves_method_path_headers_and_version() {
        let request = http::Request::builder()
            .method("GET")
            .uri("/status")
            .version(Version::HTTP_2)
            .header("x-test", "value")
            .body(())
            .expect("request");
        let parsed = map_h2_request(&request);
        assert_eq!(parsed.method(), "GET");
        assert_eq!(parsed.path(), "/status");
        assert_eq!(parsed.version(), "HTTP/2.0");
        assert_eq!(parsed.header("x-test"), Some("value"));
    }

    #[test]
    fn map_h2_request_falls_back_to_http11_for_other_versions() {
        let request = http::Request::builder()
            .method("GET")
            .uri("/legacy")
            .version(Version::HTTP_11)
            .body(())
            .expect("request");
        let parsed = map_h2_request(&request);
        assert_eq!(parsed.version(), "HTTP/1.1");
    }

    #[test]
    fn build_h2_head_rejects_invalid_header_name() {
        let mut response =
            crate::response::Response::new(200, "OK", Vec::new());
        response.add_header("bad header", "value");
        let result = build_h2_head(&response);
        assert!(matches!(result, Err(ServerError::Custom(_))));
    }

    #[tokio::test]
    async fn handle_h2_connection_reports_handshake_error_on_invalid_preface()
    {
        let root = site_root(b"hello", b"404");
        let (addr, accept_task) = spawn_connection_handler(&root).await;

        let mut client =
            std::net::TcpStream::connect(&addr).expect("connect");
        client
            .write_all(b"this-is-not-http2")
            .expect("write invalid preface");

        let result = accept_task.await.expect("join");
        assert!(matches!(result, Err(ServerError::Custom(_))));
    }

    #[tokio::test]
    async fn http2_server_returns_404_for_missing_resource() {
        let root = site_root(b"hello-h2", b"404 page");
        let (addr, task) = spawn_http2_server(&root).await;

        let (status, body) =
            h2_request(&addr, "GET", "http://localhost/does-not-exist").await;
        assert_eq!(status, 404);
        assert_eq!(body, b"404 page");
        task.abort();
    }

    #[tokio::test]
    async fn http2_server_returns_405_for_unsupported_method() {
        let root = site_root(b"hello-h2", b"404");
        let (addr, task) = spawn_http2_server(&root).await;

        let (status, _body) =
            h2_request(&addr, "POST", "http://localhost/").await;
        assert_eq!(status, 405);
        task.abort();
    }

    #[tokio::test]
    async fn start_http2_handles_invalid_client_preface() {
        let root = site_root(b"hello-h2", b"404");
        let (addr, task) = spawn_http2_server(&root).await;

        let mut client =
            std::net::TcpStream::connect(&addr).expect("connect");
        client
            .write_all(b"not-http2")
            .expect("write invalid preface");
        sleep(Duration::from_millis(40)).await;

        // The broken connection must only fail its own task: the server
        // keeps accepting and serving well-formed HTTP/2 clients.
        let (status, body) =
            h2_request(&addr, "GET", "http://localhost/").await;
        assert_eq!(status, 200);
        assert_eq!(body, b"hello-h2");
        task.abort();
    }

    #[tokio::test]
    async fn handle_h2_connection_returns_ok_when_client_closes_cleanly()
    {
        let root = site_root(b"hello-h2", b"404");
        let (addr, accept_task) = spawn_connection_handler(&root).await;

        let stream = tokio::net::TcpStream::connect(&addr)
            .await
            .expect("connect");
        let (mut client, connection) =
            h2::client::handshake(stream).await.expect("handshake");
        let conn_task = tokio::spawn(connection);

        let request = http::Request::builder()
            .method("GET")
            .uri("http://localhost/")
            .body(())
            .expect("request");
        let (response_future, send_stream) =
            client.send_request(request, true).expect("send request");
        let response = response_future.await.expect("response");
        assert_eq!(response.status().as_u16(), 200);

        // Release every client-side handle (including the request's send
        // stream) so the h2 client can actually shut the connection down.
        drop(response);
        drop(send_stream);
        drop(client);
        let _ = timeout(Duration::from_millis(500), conn_task).await;

        let result = timeout(Duration::from_millis(500), accept_task)
            .await
            .expect("handler should finish once the client disconnects")
            .expect("join");
        assert!(matches!(result, Ok(()) | Err(ServerError::Custom(_))));
    }

    #[tokio::test]
    async fn handle_h2_connection_maps_accept_errors() {
        let root = site_root(b"hello", b"404");
        let (addr, accept_task) = spawn_connection_handler(&root).await;

        let mut client = tokio::net::TcpStream::connect(&addr)
            .await
            .expect("connect");
        client
            .write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
            .await
            .expect("preface");
        // A frame header of an unknown type on stream 0 instead of the
        // mandatory initial SETTINGS frame.
        client
            .write_all(&[0, 0, 1, 0xff, 0, 0, 0, 0, 0, 0x00])
            .await
            .expect("malformed frame");
        let _ = client.shutdown().await;

        let result = accept_task.await.expect("join");
        assert!(
            result.is_ok()
                || matches!(result, Err(ServerError::Custom(_)))
        );
    }
}