From b68719ff0bc18c1075a9e6fd39de7f2da2425087 Mon Sep 17 00:00:00 2001 From: Atharva Sharma Date: Fri, 10 Apr 2026 20:13:09 +0530 Subject: [PATCH 1/4] feat: add AutoHead feature to automatically register HEAD routes for GET requests --- echo.go | 86 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/echo.go b/echo.go index 5e706f8bd..bb7ce4aab 100644 --- a/echo.go +++ b/echo.go @@ -54,6 +54,7 @@ import ( "os" "os/signal" "path/filepath" + "strconv" "strings" "sync" "sync/atomic" @@ -100,6 +101,7 @@ type Echo struct { // formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm) formParseMaxMemory int64 + AutoHead bool } // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. @@ -288,6 +290,11 @@ type Config struct { // FormParseMaxMemory is default value for memory limit that is used // when parsing multipart forms (See (*http.Request).ParseMultipartForm) FormParseMaxMemory int64 + + // AutoHead enables automatic registration of HEAD routes for GET routes. + // When enabled, a HEAD request to a GET-only path will be handled automatically + // using the same handler as GET, with the response body suppressed. + AutoHead bool } // NewWithConfig creates an instance of Echo with given configuration. @@ -326,6 +333,9 @@ func NewWithConfig(config Config) *Echo { if config.FormParseMaxMemory > 0 { e.formParseMaxMemory = config.FormParseMaxMemory } + if config.AutoHead { + e.AutoHead = config.AutoHead + } return e } @@ -421,6 +431,67 @@ func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { } } +// headResponseWriter wraps an http.ResponseWriter and suppresses the response body +// while preserving headers and status code. Used for automatic HEAD route handling. +// It counts the bytes that would have been written so we can set Content-Length accurately. +type headResponseWriter struct { + http.ResponseWriter + bytesWritten int64 + statusCode int + wroteHeader bool +} + +// Write intercepts writes to the response body and counts bytes without actually writing them. +func (hw *headResponseWriter) Write(b []byte) (int, error) { + if !hw.wroteHeader { + hw.statusCode = http.StatusOK + hw.wroteHeader = true + } + hw.bytesWritten += int64(len(b)) + // Return success without actually writing the body for HEAD requests + return len(b), nil +} + +// WriteHeader intercepts the status code but still writes it to the underlying ResponseWriter. +func (hw *headResponseWriter) WriteHeader(statusCode int) { + if !hw.wroteHeader { + hw.statusCode = statusCode + hw.wroteHeader = true + hw.ResponseWriter.WriteHeader(statusCode) + } +} + +// Unwrap returns the underlying http.ResponseWriter for compatibility with echo.Response unwrapping. +func (hw *headResponseWriter) Unwrap() http.ResponseWriter { + return hw.ResponseWriter +} + +func wrapHeadHandler(handler HandlerFunc) HandlerFunc { + return func(c *Context) error { + if c.Request().Method != http.MethodHead { + return handler(c) + } + originalWriter := c.Response() + headWriter := &headResponseWriter{ResponseWriter: originalWriter} + + c.SetResponse(headWriter) + defer func() { + c.SetResponse(originalWriter) + }() + err := handler(c) + + if headWriter.bytesWritten > 0 { + originalWriter.Header().Set("Content-Length", strconv.FormatInt(headWriter.bytesWritten, 10)) + } + + if !headWriter.wroteHeader && headWriter.statusCode > 0 { + originalWriter.WriteHeader(headWriter.statusCode) + } + + return err + } +} + // Pre adds middleware to the chain which is run before router tries to find matching route. // Meaning middleware is executed even for 404 (not found) cases. func (e *Echo) Pre(middleware ...MiddlewareFunc) { @@ -634,6 +705,20 @@ func (e *Echo) add(route Route) (RouteInfo, error) { if paramsCount > e.contextPathParamAllocSize.Load() { e.contextPathParamAllocSize.Store(paramsCount) } + + // Auto-register HEAD route for GET if AutoHead is enabled + if e.AutoHead && route.Method == http.MethodGet { + headRoute := Route{ + Method: http.MethodHead, + Path: route.Path, + Handler: wrapHeadHandler(route.Handler), + Middlewares: route.Middlewares, + Name: route.Name, + } + // Attempt to add HEAD route, but ignore errors if an explicit HEAD route already exists + _, _ = e.router.Add(headRoute) + } + return ri, nil } @@ -642,6 +727,7 @@ func (e *Echo) add(route Route) (RouteInfo, error) { func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { ri, err := e.add( Route{ + Method: method, Path: path, Handler: handler, From 39c6c71ef56a3b65bf8dd5b17afbcf84aa5a5c99 Mon Sep 17 00:00:00 2001 From: Atharva Sharma Date: Fri, 10 Apr 2026 20:13:33 +0530 Subject: [PATCH 2/4] test: add tests for AutoHead feature --- echo_test.go | 173 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) diff --git a/echo_test.go b/echo_test.go index b5045e111..950a52042 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1233,6 +1233,159 @@ func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Code) } +func TestAutoHeadRoute(t *testing.T) { + tests := []struct { + name string + autoHead bool + method string + wantBody bool + wantCode int + wantCLen bool // expect Content-Length header + }{ + { + name: "AutoHead disabled - HEAD returns 405", + autoHead: false, + method: http.MethodHead, + wantCode: http.StatusMethodNotAllowed, + wantBody: false, + }, + { + name: "AutoHead enabled - HEAD returns 200 with Content-Length", + autoHead: true, + method: http.MethodHead, + wantCode: http.StatusOK, + wantBody: false, + wantCLen: true, + }, + { + name: "GET request works normally with AutoHead enabled", + autoHead: true, + method: http.MethodGet, + wantCode: http.StatusOK, + wantBody: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create Echo instance with AutoHead configuration + e := New() + e.AutoHead = tt.autoHead + + // Register a simple GET route + testBody := "Hello, World!" + e.GET("/hello", func(c *Context) error { + return c.String(http.StatusOK, testBody) + }) + + // Create request and response + req := httptest.NewRequest(tt.method, "/hello", nil) + rec := httptest.NewRecorder() + + // Serve the request + e.ServeHTTP(rec, req) + + // Verify status code + if rec.Code != tt.wantCode { + t.Errorf("expected status %d, got %d", tt.wantCode, rec.Code) + } + + // Verify response body + if tt.wantBody { + if rec.Body.String() != testBody { + t.Errorf("expected body %q, got %q", testBody, rec.Body.String()) + } + } else { + if rec.Body.String() != "" { + t.Errorf("expected empty body for HEAD, got %q", rec.Body.String()) + } + } + + // Verify Content-Length header for HEAD + if tt.wantCLen && tt.method == http.MethodHead { + clen := rec.Header().Get("Content-Length") + if clen == "" { + t.Error("expected Content-Length header for HEAD request") + } + } + }) + } +} + +func TestAutoHeadExplicitHeadTakesPrecedence(t *testing.T) { + e := New() + e.AutoHead = true + + // Register explicit HEAD route FIRST with custom behavior + e.HEAD("/api/users", func(c *Context) error { + c.Response().Header().Set("X-Custom-Header", "explicit-head") + return c.NoContent(http.StatusOK) + }) + + // Then register GET route - AutoHead will try to add a HEAD route but fail silently + // since one already exists + e.GET("/api/users", func(c *Context) error { + return c.JSON(http.StatusOK, map[string]string{"name": "John"}) + }) + + // Test that the explicit HEAD route behavior is preserved + req := httptest.NewRequest(http.MethodHead, "/api/users", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", rec.Code) + } + + if rec.Header().Get("X-Custom-Header") != "explicit-head" { + t.Error("expected explicit HEAD route to be used") + } + + // Verify body is empty + if rec.Body.String() != "" { + t.Errorf("expected empty body for HEAD, got %q", rec.Body.String()) + } +} + +func TestAutoHeadWithMiddleware(t *testing.T) { + e := New() + e.AutoHead = true + + // Add request logger middleware + middlewareExecuted := false + e.Use(func(next HandlerFunc) HandlerFunc { + return func(c *Context) error { + middlewareExecuted = true + c.Response().Header().Set("X-Middleware", "executed") + return next(c) + } + }) + + // Register GET route + e.GET("/test", func(c *Context) error { + return c.String(http.StatusOK, "test response") + }) + + // Test HEAD request goes through middleware + req := httptest.NewRequest(http.MethodHead, "/test", nil) + rec := httptest.NewRecorder() + + middlewareExecuted = false + e.ServeHTTP(rec, req) + + if !middlewareExecuted { + t.Error("middleware should execute for automatic HEAD route") + } + + if rec.Header().Get("X-Middleware") != "executed" { + t.Error("middleware header not set") + } + + if rec.Body.String() != "" { + t.Errorf("expected empty body for HEAD, got %q", rec.Body.String()) + } +} + func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -1278,3 +1431,23 @@ func BenchmarkEchoGitHubAPIMisses(b *testing.B) { func BenchmarkEchoParseAPI(b *testing.B) { benchmarkEchoRoutes(b, parseAPI) } + +func BenchmarkAutoHeadRoute(b *testing.B) { + e := New() + e.AutoHead = true + + e.GET("/bench", func(c *Context) error { + return c.String(http.StatusOK, "benchmark response body") + }) + + req := httptest.NewRequest(http.MethodHead, "/bench", nil) + rec := httptest.NewRecorder() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rec.Body.Reset() + e.ServeHTTP(rec, req) + } +} From ea953fdcccece2c9cb6a865a37763e21521e9f63 Mon Sep 17 00:00:00 2001 From: Atharva Sharma Date: Wed, 15 Apr 2026 15:20:47 +0530 Subject: [PATCH 3/4] style: remove extra blank line --- echo.go | 1 - 1 file changed, 1 deletion(-) diff --git a/echo.go b/echo.go index bb7ce4aab..4e1e2306e 100644 --- a/echo.go +++ b/echo.go @@ -727,7 +727,6 @@ func (e *Echo) add(route Route) (RouteInfo, error) { func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { ri, err := e.add( Route{ - Method: method, Path: path, Handler: handler, From 49b5ca949659a21a1c9c35de40793b8cd174ab15 Mon Sep 17 00:00:00 2001 From: toim Date: Sun, 19 Apr 2026 20:43:24 +0300 Subject: [PATCH 4/4] AutoHandleHEAD enables automatic handling of HTTP HEAD requests by falling back to the corresponding GET route --- context.go | 10 +-- echo.go | 85 ------------------------ echo_test.go | 180 +++++++++++++-------------------------------------- response.go | 87 +++++++++++++++++++++++++ router.go | 32 +++++++-- 5 files changed, 163 insertions(+), 231 deletions(-) diff --git a/context.go b/context.go index ec7fdd998..510f51855 100644 --- a/context.go +++ b/context.go @@ -467,13 +467,9 @@ func (c *Context) json(code int, i any, indent string) error { // as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until // (global) error handler decides correct status code for the error to be sent to the client. // For that we need to use writer that can store the proposed status code until the first Write is called. - if r, err := UnwrapResponse(c.response); err == nil { - r.Status = code - } else { - resp := c.Response() - c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code}) - defer c.SetResponse(resp) - } + resp := c.Response() + c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code}) + defer c.SetResponse(resp) return c.echo.JSONSerializer.Serialize(c, i, indent) } diff --git a/echo.go b/echo.go index 4e1e2306e..5e706f8bd 100644 --- a/echo.go +++ b/echo.go @@ -54,7 +54,6 @@ import ( "os" "os/signal" "path/filepath" - "strconv" "strings" "sync" "sync/atomic" @@ -101,7 +100,6 @@ type Echo struct { // formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm) formParseMaxMemory int64 - AutoHead bool } // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. @@ -290,11 +288,6 @@ type Config struct { // FormParseMaxMemory is default value for memory limit that is used // when parsing multipart forms (See (*http.Request).ParseMultipartForm) FormParseMaxMemory int64 - - // AutoHead enables automatic registration of HEAD routes for GET routes. - // When enabled, a HEAD request to a GET-only path will be handled automatically - // using the same handler as GET, with the response body suppressed. - AutoHead bool } // NewWithConfig creates an instance of Echo with given configuration. @@ -333,9 +326,6 @@ func NewWithConfig(config Config) *Echo { if config.FormParseMaxMemory > 0 { e.formParseMaxMemory = config.FormParseMaxMemory } - if config.AutoHead { - e.AutoHead = config.AutoHead - } return e } @@ -431,67 +421,6 @@ func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { } } -// headResponseWriter wraps an http.ResponseWriter and suppresses the response body -// while preserving headers and status code. Used for automatic HEAD route handling. -// It counts the bytes that would have been written so we can set Content-Length accurately. -type headResponseWriter struct { - http.ResponseWriter - bytesWritten int64 - statusCode int - wroteHeader bool -} - -// Write intercepts writes to the response body and counts bytes without actually writing them. -func (hw *headResponseWriter) Write(b []byte) (int, error) { - if !hw.wroteHeader { - hw.statusCode = http.StatusOK - hw.wroteHeader = true - } - hw.bytesWritten += int64(len(b)) - // Return success without actually writing the body for HEAD requests - return len(b), nil -} - -// WriteHeader intercepts the status code but still writes it to the underlying ResponseWriter. -func (hw *headResponseWriter) WriteHeader(statusCode int) { - if !hw.wroteHeader { - hw.statusCode = statusCode - hw.wroteHeader = true - hw.ResponseWriter.WriteHeader(statusCode) - } -} - -// Unwrap returns the underlying http.ResponseWriter for compatibility with echo.Response unwrapping. -func (hw *headResponseWriter) Unwrap() http.ResponseWriter { - return hw.ResponseWriter -} - -func wrapHeadHandler(handler HandlerFunc) HandlerFunc { - return func(c *Context) error { - if c.Request().Method != http.MethodHead { - return handler(c) - } - originalWriter := c.Response() - headWriter := &headResponseWriter{ResponseWriter: originalWriter} - - c.SetResponse(headWriter) - defer func() { - c.SetResponse(originalWriter) - }() - err := handler(c) - - if headWriter.bytesWritten > 0 { - originalWriter.Header().Set("Content-Length", strconv.FormatInt(headWriter.bytesWritten, 10)) - } - - if !headWriter.wroteHeader && headWriter.statusCode > 0 { - originalWriter.WriteHeader(headWriter.statusCode) - } - - return err - } -} - // Pre adds middleware to the chain which is run before router tries to find matching route. // Meaning middleware is executed even for 404 (not found) cases. func (e *Echo) Pre(middleware ...MiddlewareFunc) { @@ -705,20 +634,6 @@ func (e *Echo) add(route Route) (RouteInfo, error) { if paramsCount > e.contextPathParamAllocSize.Load() { e.contextPathParamAllocSize.Store(paramsCount) } - - // Auto-register HEAD route for GET if AutoHead is enabled - if e.AutoHead && route.Method == http.MethodGet { - headRoute := Route{ - Method: http.MethodHead, - Path: route.Path, - Handler: wrapHeadHandler(route.Handler), - Middlewares: route.Middlewares, - Name: route.Name, - } - // Attempt to add HEAD route, but ignore errors if an explicit HEAD route already exists - _, _ = e.router.Add(headRoute) - } - return ri, nil } diff --git a/echo_test.go b/echo_test.go index 950a52042..6847e56bd 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1233,157 +1233,87 @@ func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Code) } -func TestAutoHeadRoute(t *testing.T) { +func TestRouterAutoHandleHEADFullHTTPHandlerFlow(t *testing.T) { tests := []struct { - name string - autoHead bool - method string - wantBody bool - wantCode int - wantCLen bool // expect Content-Length header + name string + givenAutoHandleHEAD bool + whenMethod string + expectBody string + expectCode int + expectContentLength string }{ { - name: "AutoHead disabled - HEAD returns 405", - autoHead: false, - method: http.MethodHead, - wantCode: http.StatusMethodNotAllowed, - wantBody: false, + name: "AutoHandleHEAD disabled - HEAD returns 405", + givenAutoHandleHEAD: false, + whenMethod: http.MethodHead, + expectCode: http.StatusMethodNotAllowed, + expectBody: "", }, { - name: "AutoHead enabled - HEAD returns 200 with Content-Length", - autoHead: true, - method: http.MethodHead, - wantCode: http.StatusOK, - wantBody: false, - wantCLen: true, + name: "AutoHandleHEAD enabled - HEAD returns 200 with Content-Length", + givenAutoHandleHEAD: true, + whenMethod: http.MethodHead, + expectCode: http.StatusOK, + expectBody: "", + expectContentLength: "4", }, { - name: "GET request works normally with AutoHead enabled", - autoHead: true, - method: http.MethodGet, - wantCode: http.StatusOK, - wantBody: true, + name: "GET request works normally with AutoHandleHEAD enabled", + givenAutoHandleHEAD: true, + whenMethod: http.MethodGet, + expectCode: http.StatusOK, + expectBody: "test", }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create Echo instance with AutoHead configuration - e := New() - e.AutoHead = tt.autoHead + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + e := NewWithConfig(Config{ + Router: NewRouter(RouterConfig{ + AutoHandleHEAD: tc.givenAutoHandleHEAD, + }), + }) - // Register a simple GET route - testBody := "Hello, World!" e.GET("/hello", func(c *Context) error { - return c.String(http.StatusOK, testBody) + return c.String(http.StatusOK, "test") }) - // Create request and response - req := httptest.NewRequest(tt.method, "/hello", nil) + req := httptest.NewRequest(tc.whenMethod, "/hello", nil) rec := httptest.NewRecorder() - // Serve the request e.ServeHTTP(rec, req) - // Verify status code - if rec.Code != tt.wantCode { - t.Errorf("expected status %d, got %d", tt.wantCode, rec.Code) - } - - // Verify response body - if tt.wantBody { - if rec.Body.String() != testBody { - t.Errorf("expected body %q, got %q", testBody, rec.Body.String()) - } - } else { - if rec.Body.String() != "" { - t.Errorf("expected empty body for HEAD, got %q", rec.Body.String()) - } - } - - // Verify Content-Length header for HEAD - if tt.wantCLen && tt.method == http.MethodHead { - clen := rec.Header().Get("Content-Length") - if clen == "" { - t.Error("expected Content-Length header for HEAD request") - } - } + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectContentLength, rec.Header().Get(HeaderContentLength)) + assert.Equal(t, tc.expectBody, rec.Body.String()) }) } } func TestAutoHeadExplicitHeadTakesPrecedence(t *testing.T) { - e := New() - e.AutoHead = true + e := NewWithConfig(Config{ + Router: NewRouter(RouterConfig{ + AutoHandleHEAD: true, + }), + }) // Register explicit HEAD route FIRST with custom behavior e.HEAD("/api/users", func(c *Context) error { c.Response().Header().Set("X-Custom-Header", "explicit-head") - return c.NoContent(http.StatusOK) + return c.NoContent(http.StatusTeapot) }) - // Then register GET route - AutoHead will try to add a HEAD route but fail silently - // since one already exists e.GET("/api/users", func(c *Context) error { - return c.JSON(http.StatusOK, map[string]string{"name": "John"}) + return c.JSON(http.StatusNotFound, map[string]string{"name": "John"}) }) - // Test that the explicit HEAD route behavior is preserved req := httptest.NewRequest(http.MethodHead, "/api/users", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Errorf("expected status 200, got %d", rec.Code) - } - - if rec.Header().Get("X-Custom-Header") != "explicit-head" { - t.Error("expected explicit HEAD route to be used") - } - - // Verify body is empty - if rec.Body.String() != "" { - t.Errorf("expected empty body for HEAD, got %q", rec.Body.String()) - } -} - -func TestAutoHeadWithMiddleware(t *testing.T) { - e := New() - e.AutoHead = true - - // Add request logger middleware - middlewareExecuted := false - e.Use(func(next HandlerFunc) HandlerFunc { - return func(c *Context) error { - middlewareExecuted = true - c.Response().Header().Set("X-Middleware", "executed") - return next(c) - } - }) - - // Register GET route - e.GET("/test", func(c *Context) error { - return c.String(http.StatusOK, "test response") - }) - - // Test HEAD request goes through middleware - req := httptest.NewRequest(http.MethodHead, "/test", nil) - rec := httptest.NewRecorder() - - middlewareExecuted = false - e.ServeHTTP(rec, req) - - if !middlewareExecuted { - t.Error("middleware should execute for automatic HEAD route") - } - - if rec.Header().Get("X-Middleware") != "executed" { - t.Error("middleware header not set") - } - - if rec.Body.String() != "" { - t.Errorf("expected empty body for HEAD, got %q", rec.Body.String()) - } + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, "explicit-head", rec.Header().Get("X-Custom-Header")) + assert.Equal(t, "", rec.Body.String()) } func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { @@ -1431,23 +1361,3 @@ func BenchmarkEchoGitHubAPIMisses(b *testing.B) { func BenchmarkEchoParseAPI(b *testing.B) { benchmarkEchoRoutes(b, parseAPI) } - -func BenchmarkAutoHeadRoute(b *testing.B) { - e := New() - e.AutoHead = true - - e.GET("/bench", func(c *Context) error { - return c.String(http.StatusOK, "benchmark response body") - }) - - req := httptest.NewRequest(http.MethodHead, "/bench", nil) - rec := httptest.NewRecorder() - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - rec.Body.Reset() - e.ServeHTTP(rec, req) - } -} diff --git a/response.go b/response.go index 4da729c47..c018af2cb 100644 --- a/response.go +++ b/response.go @@ -10,6 +10,7 @@ import ( "log/slog" "net" "net/http" + "strconv" ) // Response wraps an http.ResponseWriter and implements its interface to be used @@ -170,3 +171,89 @@ func (w *delayedStatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (w *delayedStatusWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter } + +// headResponseWriter captures the response that a GET handler would produce for a +// rewritten HEAD request, suppresses the body, and preserves response metadata. +// +// The writer buffers status until the downstream handler returns, so it +// can compute a Content-Length value from the number of body bytes that would have +// been written by the GET handler. If the handler already sets Content-Length +// explicitly, that value is preserved. +// +// Flush is intentionally a no-op because emitting headers early would prevent +// finalizing Content-Length after the handler completes. +type headResponseWriter struct { + rw http.ResponseWriter + status int + wroteStatus bool + bodyBytes int64 +} + +func (w *headResponseWriter) Header() http.Header { + return w.rw.Header() +} + +func (w *headResponseWriter) WriteHeader(code int) { + if w.wroteStatus { + return + } + w.wroteStatus = true + w.status = code +} + +func (w *headResponseWriter) Write(b []byte) (int, error) { + if !w.wroteStatus { + w.WriteHeader(http.StatusOK) + } + w.bodyBytes += int64(len(b)) + return len(b), nil +} + +func (w *headResponseWriter) Flush() { + // No-op on purpose. A HEAD response has no body, and flushing early would + // commit headers before Content-Length can be finalized. +} + +func (w *headResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return http.NewResponseController(w.rw).Hijack() +} + +func (w *headResponseWriter) Unwrap() http.ResponseWriter { + return w.rw +} + +func (w *headResponseWriter) commit() { + dst := w.rw.Header() + if dst.Get(HeaderContentLength) == "" && + dst.Get("Transfer-Encoding") == "" && + !statusMustNotHaveBody(w.status) { + dst.Set(HeaderContentLength, strconv.FormatInt(w.bodyBytes, 10)) + } + + // "commit" the Response only when the headers were written otherwise the Echo errorhandler cannot properly handle errors + if w.wroteStatus { + w.rw.WriteHeader(w.status) + } +} + +func statusMustNotHaveBody(code int) bool { + return (code >= 100 && code < 200) || + code == http.StatusNoContent || + code == http.StatusNotModified +} + +func wrapHeadHandler(handler HandlerFunc) HandlerFunc { + return func(c *Context) error { + originalWriter := c.Response() + headWriter := &headResponseWriter{rw: originalWriter} + + c.SetResponse(headWriter) + defer func() { + c.SetResponse(originalWriter) + }() + + err := handler(c) + headWriter.commit() + return err + } +} diff --git a/router.go b/router.go index 48341cb1b..86e2dfd26 100644 --- a/router.go +++ b/router.go @@ -69,6 +69,7 @@ type DefaultRouter struct { allowOverwritingRoute bool unescapePathParamValues bool useEscapedPathForRouting bool + autoHandleHEAD bool } // RouterConfig is configuration options for (default) router @@ -79,6 +80,20 @@ type RouterConfig struct { AllowOverwritingRoute bool UnescapePathParamValues bool UseEscapedPathForMatching bool + + // AutoHandleHEAD enables automatic handling of HTTP HEAD requests by + // falling back to the corresponding GET route. + // + // When enabled, a HEAD request will match the same handler as GET for + // the route, but the response body is suppressed in accordance with + // HTTP semantics. Headers (e.g., Content-Length, Content-Type) are + // preserved as if a GET request was made. + // + // Note that the GET handler is still executed, so any side effects + // (such as database queries or logging) will occur. + // + // Disabled by default. + AutoHandleHEAD bool } // NewRouter returns a new Router instance. @@ -98,6 +113,7 @@ func NewRouter(config RouterConfig) *DefaultRouter { notFoundHandler: notFoundHandler, methodNotAllowedHandler: methodNotAllowedHandler, optionsMethodHandler: optionsMethodHandler, + autoHandleHEAD: config.AutoHandleHEAD, } if config.NotFoundHandler != nil { r.notFoundHandler = config.NotFoundHandler @@ -210,7 +226,7 @@ func (m *routeMethods) set(method string, r *routeMethod) { m.updateAllowHeader() } -func (m *routeMethods) find(method string, fallbackToAny bool) *routeMethod { +func (m *routeMethods) find(method string, fallbackToAny bool, autoHandleHEAD bool) *routeMethod { var r *routeMethod switch method { case http.MethodConnect: @@ -221,6 +237,9 @@ func (m *routeMethods) find(method string, fallbackToAny bool) *routeMethod { r = m.get case http.MethodHead: r = m.head + if autoHandleHEAD && r == nil { + r = m.get + } case http.MethodOptions: r = m.options case http.MethodPatch: @@ -374,7 +393,7 @@ func (r *DefaultRouter) Remove(method string, path string) error { return errors.New("could not find route to remove by given path") } - if mh := nodeToRemove.methods.find(method, false); mh == nil { + if mh := nodeToRemove.methods.find(method, false, false); mh == nil { return errors.New("could not find route to remove by given path and method") } nodeToRemove.setHandler(method, nil) @@ -904,7 +923,7 @@ func (r *DefaultRouter) Route(c *Context) HandlerFunc { if previousBestMatchNode == nil { previousBestMatchNode = currentNode } - if h := currentNode.methods.find(req.Method, true); h != nil { + if h := currentNode.methods.find(req.Method, true, r.autoHandleHEAD); h != nil { matchedRouteMethod = h break } @@ -955,7 +974,7 @@ func (r *DefaultRouter) Route(c *Context) HandlerFunc { searchIndex += len(search) search = "" - if rMethod := currentNode.methods.find(req.Method, true); rMethod != nil { + if rMethod := currentNode.methods.find(req.Method, true, r.autoHandleHEAD); rMethod != nil { matchedRouteMethod = rMethod break } @@ -995,6 +1014,11 @@ func (r *DefaultRouter) Route(c *Context) HandlerFunc { var rInfo *RouteInfo if matchedRouteMethod != nil { rHandler = matchedRouteMethod.handler + if r.autoHandleHEAD && req.Method == http.MethodHead { + rHandler = wrapHeadHandler(rHandler) + // we are not touching rInfo.Method and let it be value from GET routeInfo + } + rPath = matchedRouteMethod.Path rInfo = matchedRouteMethod.RouteInfo } else {