Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http: record the number of bytes read when response writer is hijacked #6173

Merged
merged 3 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions modules/caddyhttp/responsewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ type responseRecorder struct {
size int
wroteHeader bool
stream bool

readSize *int
}

// NewResponseRecorder returns a new ResponseRecorder that can be
Expand Down Expand Up @@ -240,6 +242,12 @@ func (rr *responseRecorder) FlushError() error {
return nil
}

// Private interface so it can only be used in this package
// #TODO: maybe export it later
func (rr *responseRecorder) setReadSize(size *int) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing in a *int is a bit weird. Can we just pass in a int? Any reason we need a pointer at all, actually?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pointer points to the member of the wrapped request body. So when it's updated, the corresponding log entry is updated as well. int requires more changes to the code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. Is that thread-safe? (Would it ever be used across goroutines?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The net.Conn interface is not thread-safe by default unless specified explicitly. But the only race I know of is when caddy sends websocket goaway frame, which is not thread-safe in the first place. But clients are expected to reconnect anyway, the stats are a bit off (goaway frames are small), nothing breaking.

rr.readSize = size
}

func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
//nolint:bodyclose
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
Expand All @@ -249,6 +257,15 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
conn = &hijackedConn{conn, rr}
brw.Writer.Reset(conn)

buffered := brw.Reader.Buffered()
if buffered != 0 {
conn.(*hijackedConn).updateReadSize(buffered)
data, _ := brw.Peek(buffered)
brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn))
} else {
brw.Reader.Reset(conn)
}
return conn, brw, nil
}

Expand All @@ -258,6 +275,24 @@ type hijackedConn struct {
rr *responseRecorder
}

func (hc *hijackedConn) updateReadSize(n int) {
if hc.rr.readSize != nil {
*hc.rr.readSize += n
}
}

func (hc *hijackedConn) Read(p []byte) (int, error) {
n, err := hc.Conn.Read(p)
hc.updateReadSize(n)
return n, err
}

func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) {
n, err := io.Copy(w, hc.Conn)
hc.updateReadSize(int(n))
mholt marked this conversation as resolved.
Show resolved Hide resolved
return n, err
}

func (hc *hijackedConn) Write(p []byte) (int, error) {
n, err := hc.Conn.Write(p)
hc.rr.size += n
Expand Down Expand Up @@ -298,4 +333,6 @@ var (
_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
_ io.ReaderFrom = (*responseRecorder)(nil)
_ io.ReaderFrom = (*hijackedConn)(nil)

_ io.WriterTo = (*hijackedConn)(nil)
)
5 changes: 5 additions & 0 deletions modules/caddyhttp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Body != nil {
bodyReader = &lengthReader{Source: r.Body}
r.Body = bodyReader

// should always be true, private interface can only be referenced in the same package
if setReadSizer, ok := wrec.(interface{ setReadSize(*int) }); ok {
setReadSizer.setReadSize(&bodyReader.Length)
}
}

// capture the original version of the request
Expand Down
Loading