Skip to content

Commit

Permalink
Drain stream and error on trailing data
Browse files Browse the repository at this point in the history
  • Loading branch information
emcfarlane committed Jun 27, 2023
1 parent 849635c commit df2f8da
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 7 deletions.
14 changes: 14 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2138,6 +2138,20 @@ func TestStreamUnexpectedEOF(t *testing.T) {
},
expectCode: connect.CodeInvalidArgument,
expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF",
}, {
name: "stream_excess_eof",
handler: func(responseWriter http.ResponseWriter, request *http.Request) {
_, _ = responseWriter.Write(head[:])
_, _ = responseWriter.Write(payload)
// Write EOF
_, _ = responseWriter.Write([]byte{2, 0, 0, 0, 2})
_, _ = responseWriter.Write([]byte("{}"))
// Excess payload
_, _ = responseWriter.Write(head[:])
_, _ = responseWriter.Write(payload)
},
expectCode: connect.CodeUnknown,
expectMsg: fmt.Sprintf("unknown: corrupt response: %d extra bytes after end of stream", len(payload)+len(head)),
}}
for _, testcase := range testcases {
testcaseMux[t.Name()+"/"+testcase.name] = testcase.handler
Expand Down
2 changes: 1 addition & 1 deletion duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (d *duplexHTTPCall) CloseRead() error {
if d.response == nil {
return nil
}
if err := discard(d.response.Body); err != nil {
if _, err := discard(d.response.Body); err != nil {
_ = d.response.Body.Close()
return wrapIfRSTError(err)
}
Expand Down
6 changes: 6 additions & 0 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ func (r *envelopeReader) Unmarshal(message any) *Error {
if _, err := r.last.Data.ReadFrom(data); err != nil {
return errorf(CodeUnknown, "copy final envelope: %w", err)
}
// Drain the rest of the stream to ensure there is no extra data.
if n, err := discard(r.reader); err != nil {
return errorf(CodeUnknown, "corrupt response: I/O error after end-stream message: %w", err)
} else if n > 0 {
return errorf(CodeUnknown, "corrupt response: %d extra bytes after end of stream", n)
}
return errSpecialEnvelope
}

Expand Down
8 changes: 3 additions & 5 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,16 +283,14 @@ func isCommaOrSpace(c rune) bool {
return c == ',' || c == ' '
}

func discard(reader io.Reader) error {
func discard(reader io.Reader) (int64, error) {
if lr, ok := reader.(*io.LimitedReader); ok {
_, err := io.Copy(io.Discard, lr)
return err
return io.Copy(io.Discard, lr)
}
// We don't want to get stuck throwing data away forever, so limit how much
// we're willing to do here.
lr := &io.LimitedReader{R: reader, N: discardLimit}
_, err := io.Copy(io.Discard, lr)
return err
return io.Copy(io.Discard, lr)
}

// negotiateCompression determines and validates the request compression and
Expand Down
2 changes: 1 addition & 1 deletion protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ func (g *grpcClient) NewConn(
} else {
conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header {
// To access HTTP trailers, we need to read the body to EOF.
_ = discard(call)
_, _ = discard(call)
return call.ResponseTrailer()
}
}
Expand Down

0 comments on commit df2f8da

Please sign in to comment.