Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,9 @@ func (c *Context) json(code int, i any, indent string) error {
// as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until
// (global) error handler decides correct status code for the error to be sent to the client.
// For that we need to use writer that can store the proposed status code until the first Write is called.
if r, err := UnwrapResponse(c.response); err == nil {
r.Status = code
} else {
resp := c.Response()
c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code})
defer c.SetResponse(resp)
}
resp := c.Response()
c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code})
defer c.SetResponse(resp)

return c.echo.JSONSerializer.Serialize(c, i, indent)
}
Expand Down
83 changes: 83 additions & 0 deletions echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,89 @@ func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.Code)
}

func TestRouterAutoHandleHEADFullHTTPHandlerFlow(t *testing.T) {
tests := []struct {
name string
givenAutoHandleHEAD bool
whenMethod string
expectBody string
expectCode int
expectContentLength string
}{
{
name: "AutoHandleHEAD disabled - HEAD returns 405",
givenAutoHandleHEAD: false,
whenMethod: http.MethodHead,
expectCode: http.StatusMethodNotAllowed,
expectBody: "",
},
{
name: "AutoHandleHEAD enabled - HEAD returns 200 with Content-Length",
givenAutoHandleHEAD: true,
whenMethod: http.MethodHead,
expectCode: http.StatusOK,
expectBody: "",
expectContentLength: "4",
},
{
name: "GET request works normally with AutoHandleHEAD enabled",
givenAutoHandleHEAD: true,
whenMethod: http.MethodGet,
expectCode: http.StatusOK,
expectBody: "test",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
e := NewWithConfig(Config{
Router: NewRouter(RouterConfig{
AutoHandleHEAD: tc.givenAutoHandleHEAD,
}),
})

e.GET("/hello", func(c *Context) error {
return c.String(http.StatusOK, "test")
})

req := httptest.NewRequest(tc.whenMethod, "/hello", nil)
rec := httptest.NewRecorder()

e.ServeHTTP(rec, req)

assert.Equal(t, tc.expectCode, rec.Code)
assert.Equal(t, tc.expectContentLength, rec.Header().Get(HeaderContentLength))
assert.Equal(t, tc.expectBody, rec.Body.String())
})
}
}

func TestAutoHeadExplicitHeadTakesPrecedence(t *testing.T) {
e := NewWithConfig(Config{
Router: NewRouter(RouterConfig{
AutoHandleHEAD: true,
}),
})

// Register explicit HEAD route FIRST with custom behavior
e.HEAD("/api/users", func(c *Context) error {
c.Response().Header().Set("X-Custom-Header", "explicit-head")
return c.NoContent(http.StatusTeapot)
})

e.GET("/api/users", func(c *Context) error {
return c.JSON(http.StatusNotFound, map[string]string{"name": "John"})
})

req := httptest.NewRequest(http.MethodHead, "/api/users", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)

assert.Equal(t, http.StatusTeapot, rec.Code)
assert.Equal(t, "explicit-head", rec.Header().Get("X-Custom-Header"))
assert.Equal(t, "", rec.Body.String())
}

func benchmarkEchoRoutes(b *testing.B, routes []testRoute) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
Expand Down
87 changes: 87 additions & 0 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"log/slog"
"net"
"net/http"
"strconv"
)

// Response wraps an http.ResponseWriter and implements its interface to be used
Expand Down Expand Up @@ -170,3 +171,89 @@ func (w *delayedStatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (w *delayedStatusWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

// headResponseWriter captures the response that a GET handler would produce for a
// rewritten HEAD request, suppresses the body, and preserves response metadata.
//
// The writer buffers status until the downstream handler returns, so it
// can compute a Content-Length value from the number of body bytes that would have
// been written by the GET handler. If the handler already sets Content-Length
// explicitly, that value is preserved.
//
// Flush is intentionally a no-op because emitting headers early would prevent
// finalizing Content-Length after the handler completes.
type headResponseWriter struct {
rw http.ResponseWriter
status int
wroteStatus bool
bodyBytes int64
}

func (w *headResponseWriter) Header() http.Header {
return w.rw.Header()
}

func (w *headResponseWriter) WriteHeader(code int) {
if w.wroteStatus {
return
}
w.wroteStatus = true
w.status = code
}

func (w *headResponseWriter) Write(b []byte) (int, error) {
if !w.wroteStatus {
w.WriteHeader(http.StatusOK)
}
w.bodyBytes += int64(len(b))
return len(b), nil
}

func (w *headResponseWriter) Flush() {
// No-op on purpose. A HEAD response has no body, and flushing early would
// commit headers before Content-Length can be finalized.
}

func (w *headResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(w.rw).Hijack()
}

func (w *headResponseWriter) Unwrap() http.ResponseWriter {
return w.rw
}

func (w *headResponseWriter) commit() {
dst := w.rw.Header()
if dst.Get(HeaderContentLength) == "" &&
dst.Get("Transfer-Encoding") == "" &&
!statusMustNotHaveBody(w.status) {
dst.Set(HeaderContentLength, strconv.FormatInt(w.bodyBytes, 10))
}

// "commit" the Response only when the headers were written otherwise the Echo errorhandler cannot properly handle errors
if w.wroteStatus {
w.rw.WriteHeader(w.status)
}
}

func statusMustNotHaveBody(code int) bool {
return (code >= 100 && code < 200) ||
code == http.StatusNoContent ||
code == http.StatusNotModified
}

func wrapHeadHandler(handler HandlerFunc) HandlerFunc {
return func(c *Context) error {
originalWriter := c.Response()
headWriter := &headResponseWriter{rw: originalWriter}

c.SetResponse(headWriter)
defer func() {
c.SetResponse(originalWriter)
}()

err := handler(c)
headWriter.commit()
return err
}
}
32 changes: 28 additions & 4 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type DefaultRouter struct {
allowOverwritingRoute bool
unescapePathParamValues bool
useEscapedPathForRouting bool
autoHandleHEAD bool
}

// RouterConfig is configuration options for (default) router
Expand All @@ -79,6 +80,20 @@ type RouterConfig struct {
AllowOverwritingRoute bool
UnescapePathParamValues bool
UseEscapedPathForMatching bool

// AutoHandleHEAD enables automatic handling of HTTP HEAD requests by
// falling back to the corresponding GET route.
//
// When enabled, a HEAD request will match the same handler as GET for
// the route, but the response body is suppressed in accordance with
// HTTP semantics. Headers (e.g., Content-Length, Content-Type) are
// preserved as if a GET request was made.
//
// Note that the GET handler is still executed, so any side effects
// (such as database queries or logging) will occur.
//
// Disabled by default.
AutoHandleHEAD bool
}

// NewRouter returns a new Router instance.
Expand All @@ -98,6 +113,7 @@ func NewRouter(config RouterConfig) *DefaultRouter {
notFoundHandler: notFoundHandler,
methodNotAllowedHandler: methodNotAllowedHandler,
optionsMethodHandler: optionsMethodHandler,
autoHandleHEAD: config.AutoHandleHEAD,
}
if config.NotFoundHandler != nil {
r.notFoundHandler = config.NotFoundHandler
Expand Down Expand Up @@ -210,7 +226,7 @@ func (m *routeMethods) set(method string, r *routeMethod) {
m.updateAllowHeader()
}

func (m *routeMethods) find(method string, fallbackToAny bool) *routeMethod {
func (m *routeMethods) find(method string, fallbackToAny bool, autoHandleHEAD bool) *routeMethod {
var r *routeMethod
switch method {
case http.MethodConnect:
Expand All @@ -221,6 +237,9 @@ func (m *routeMethods) find(method string, fallbackToAny bool) *routeMethod {
r = m.get
case http.MethodHead:
r = m.head
if autoHandleHEAD && r == nil {
r = m.get
}
case http.MethodOptions:
r = m.options
case http.MethodPatch:
Expand Down Expand Up @@ -374,7 +393,7 @@ func (r *DefaultRouter) Remove(method string, path string) error {
return errors.New("could not find route to remove by given path")
}

if mh := nodeToRemove.methods.find(method, false); mh == nil {
if mh := nodeToRemove.methods.find(method, false, false); mh == nil {
return errors.New("could not find route to remove by given path and method")
}
nodeToRemove.setHandler(method, nil)
Expand Down Expand Up @@ -904,7 +923,7 @@ func (r *DefaultRouter) Route(c *Context) HandlerFunc {
if previousBestMatchNode == nil {
previousBestMatchNode = currentNode
}
if h := currentNode.methods.find(req.Method, true); h != nil {
if h := currentNode.methods.find(req.Method, true, r.autoHandleHEAD); h != nil {
matchedRouteMethod = h
break
}
Expand Down Expand Up @@ -955,7 +974,7 @@ func (r *DefaultRouter) Route(c *Context) HandlerFunc {
searchIndex += len(search)
search = ""

if rMethod := currentNode.methods.find(req.Method, true); rMethod != nil {
if rMethod := currentNode.methods.find(req.Method, true, r.autoHandleHEAD); rMethod != nil {
matchedRouteMethod = rMethod
break
}
Expand Down Expand Up @@ -995,6 +1014,11 @@ func (r *DefaultRouter) Route(c *Context) HandlerFunc {
var rInfo *RouteInfo
if matchedRouteMethod != nil {
rHandler = matchedRouteMethod.handler
if r.autoHandleHEAD && req.Method == http.MethodHead {
rHandler = wrapHeadHandler(rHandler)
// we are not touching rInfo.Method and let it be value from GET routeInfo
}

rPath = matchedRouteMethod.Path
rInfo = matchedRouteMethod.RouteInfo
} else {
Expand Down
Loading