From 8264a962d1c21d52e8fca50af064c5535c3708d3 Mon Sep 17 00:00:00 2001 From: Chris Cotter Date: Thu, 28 Mar 2024 20:13:19 +0000 Subject: [PATCH] feat(storage): Implement io.WriterTo in Reader (#9659) * feat(storage): Implement io.WriterTo in Reader This allows the gRPC Reader to write directly into the application write buffer, saving a data copy. Users can get the benefit of this directly by explicitly calling Reader.WriteTo, but they can also benefit implicitly if they are calling io.Copy. A bunch of checksum logic had to be moved from the parent Reader into the transport Readers to make this work, since we need to update the checksum for every message read in WriteTo. * fix conf test object vars * fix review comments * fix EOF case. --- storage/grpc_client.go | 115 +++++++++++++++++++++++++++--- storage/http_client.go | 34 +++++++-- storage/reader.go | 26 ++++--- storage/retry_conformance_test.go | 48 +++++++++++-- 4 files changed, 189 insertions(+), 34 deletions(-) diff --git a/storage/grpc_client.go b/storage/grpc_client.go index c8c019da5137..e337213f03f2 100644 --- a/storage/grpc_client.go +++ b/storage/grpc_client.go @@ -19,6 +19,7 @@ import ( "encoding/base64" "errors" "fmt" + "hash/crc32" "io" "net/url" "os" @@ -1042,6 +1043,16 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange // This is the size of the entire object, even if only a range was requested. size := obj.GetSize() + // Only support checksums when reading an entire object, not a range. 
+ var ( + wantCRC uint32 + checkCRC bool + ) + if checksums := msg.GetObjectChecksums(); checksums != nil && checksums.Crc32C != nil && params.offset == 0 && params.length < 0 { + wantCRC = checksums.GetCrc32C() + checkCRC = true + } + r = &Reader{ Attrs: ReaderObjectAttrs{ Size: size, @@ -1063,7 +1074,10 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange settings: s, zeroRange: params.length == 0, databuf: databuf, + wantCRC: wantCRC, + checkCRC: checkCRC, }, + checkCRC: checkCRC, } cr := msg.GetContentRange() @@ -1081,12 +1095,6 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange r.reader.Close() } - // Only support checksums when reading an entire object, not a range. - if checksums := msg.GetObjectChecksums(); checksums != nil && checksums.Crc32C != nil && params.offset == 0 && params.length < 0 { - r.wantCRC = checksums.GetCrc32C() - r.checkCRC = true - } - return r, nil } @@ -1464,12 +1472,34 @@ type gRPCReader struct { databuf []byte cancel context.CancelFunc settings *settings + checkCRC bool // should we check the CRC? + wantCRC uint32 // the CRC32c value the server sent in the header + gotCRC uint32 // running crc +} + +// Update the running CRC with the data in the slice, if CRC checking was enabled. +func (r *gRPCReader) updateCRC(b []byte) { + if r.checkCRC { + r.gotCRC = crc32.Update(r.gotCRC, crc32cTable, b) + } +} + +// Checks whether the CRC matches at the conclusion of a read, if CRC checking was enabled. +func (r *gRPCReader) runCRCCheck() error { + if r.checkCRC && r.gotCRC != r.wantCRC { + return fmt.Errorf("storage: bad CRC on read: got %d, want %d", r.gotCRC, r.wantCRC) + } + return nil } // Read reads bytes into the user's buffer from an open gRPC stream. func (r *gRPCReader) Read(p []byte) (int, error) { - // The entire object has been read by this reader, return EOF. + // The entire object has been read by this reader, check the checksum if + // necessary and return EOF. 
if r.size == r.seen || r.zeroRange { + if err := r.runCRCCheck(); err != nil { + return 0, err + } return 0, io.EOF } @@ -1478,7 +1508,7 @@ func (r *gRPCReader) Read(p []byte) (int, error) { // using the same reader. One encounters an error and the stream is closed // and then reopened while the other routine attempts to read from it. if r.stream == nil { - return 0, fmt.Errorf("reader has been closed") + return 0, fmt.Errorf("storage: reader has been closed") } var n int @@ -1487,6 +1517,7 @@ func (r *gRPCReader) Read(p []byte) (int, error) { if len(r.leftovers) > 0 { n = copy(p, r.leftovers) r.seen += int64(n) + r.updateCRC(p[:n]) r.leftovers = r.leftovers[n:] return n, nil } @@ -1512,10 +1543,78 @@ func (r *gRPCReader) Read(p []byte) (int, error) { r.leftovers = content[n:] } r.seen += int64(n) + r.updateCRC(p[:n]) return n, nil } +// WriteTo writes all the data requested by the Reader into w, implementing +// io.WriterTo. +func (r *gRPCReader) WriteTo(w io.Writer) (int64, error) { + // The entire object has been read by this reader, check the checksum if + // necessary and return nil. + if r.size == r.seen || r.zeroRange { + if err := r.runCRCCheck(); err != nil { + return 0, err + } + return 0, nil + } + + // No stream to read from, either never initialized or Close was called. + // Note: There is a potential concurrency issue if multiple routines are + // using the same reader. One encounters an error and the stream is closed + // and then reopened while the other routine attempts to read from it. + if r.stream == nil { + return 0, fmt.Errorf("storage: reader has been closed") + } + + // Track the bytes already seen before this call. + var alreadySeen = r.seen + + // Write any leftovers to the stream. There will be some leftovers from the + // original NewRangeReader call. + if len(r.leftovers) > 0 { + // Write() will write the entire leftovers slice unless there is an error. 
+ written, err := w.Write(r.leftovers) + r.seen += int64(written) + r.updateCRC(r.leftovers) + r.leftovers = nil + if err != nil { + return r.seen - alreadySeen, err + } + } + + // Loop and receive additional messages until the entire data is written. + for { + // Attempt to receive the next message on the stream. + // Will terminate with io.EOF once data has all come through. + // recv() handles stream reopening and retry logic so no need for retries here. + msg, err := r.recv() + if err != nil { + if err == io.EOF { + // We are done; check the checksum if necessary and return. + err = r.runCRCCheck() + } + return r.seen - alreadySeen, err + } + + // TODO: Determine if we need to capture incremental CRC32C for this + // chunk. The Object CRC32C checksum is captured when directed to read + // the entire Object. If directed to read a range, we may need to + // calculate the range's checksum for verification if the checksum is + // present in the response here. + // TODO: Figure out if we need to support decompressive transcoding + // https://cloud.google.com/storage/docs/transcoding. + written, err := w.Write(msg) + r.seen += int64(written) + r.updateCRC(msg) + if err != nil { + return r.seen - alreadySeen, err + } + } + +} + // Close cancels the read stream's context in order for it to be closed and // collected. func (r *gRPCReader) Close() error { diff --git a/storage/http_client.go b/storage/http_client.go index e3e0d761bb08..f75d93897d9d 100644 --- a/storage/http_client.go +++ b/storage/http_client.go @@ -19,6 +19,7 @@ import ( "encoding/base64" "errors" "fmt" + "hash/crc32" "io" "io/ioutil" "net/http" @@ -1218,9 +1219,12 @@ func (c *httpStorageClient) DeleteNotification(ctx context.Context, bucket strin } type httpReader struct { - body io.ReadCloser - seen int64 - reopen func(seen int64) (*http.Response, error) + body io.ReadCloser + seen int64 + reopen func(seen int64) (*http.Response, error) + checkCRC bool // should we check the CRC? 
+ wantCRC uint32 // the CRC32c value the server sent in the header + gotCRC uint32 // running crc } func (r *httpReader) Read(p []byte) (int, error) { @@ -1229,7 +1233,22 @@ func (r *httpReader) Read(p []byte) (int, error) { m, err := r.body.Read(p[n:]) n += m r.seen += int64(m) - if err == nil || err == io.EOF { + if r.checkCRC { + r.gotCRC = crc32.Update(r.gotCRC, crc32cTable, p[:n]) + } + if err == nil { + return n, nil + } + if err == io.EOF { + // Check CRC here. It would be natural to check it in Close, but + // everybody defers Close on the assumption that it doesn't return + // anything worth looking at. + if r.checkCRC { + if r.gotCRC != r.wantCRC { + return n, fmt.Errorf("storage: bad CRC on read: got %d, want %d", + r.gotCRC, r.wantCRC) + } + } return n, err } // Read failed (likely due to connection issues), but we will try to reopen @@ -1435,11 +1454,12 @@ func parseReadResponse(res *http.Response, params *newRangeReaderParams, reopen Attrs: attrs, size: size, remain: remain, - wantCRC: crc, checkCRC: checkCRC, reader: &httpReader{ - reopen: reopen, - body: body, + reopen: reopen, + body: body, + wantCRC: crc, + checkCRC: checkCRC, }, }, nil } diff --git a/storage/reader.go b/storage/reader.go index 4673a68d0789..0b228a6a76c9 100644 --- a/storage/reader.go +++ b/storage/reader.go @@ -198,9 +198,7 @@ var emptyBody = ioutil.NopCloser(strings.NewReader("")) type Reader struct { Attrs ReaderObjectAttrs seen, remain, size int64 - checkCRC bool // should we check the CRC? - wantCRC uint32 // the CRC32c value the server sent in the header - gotCRC uint32 // running crc + checkCRC bool // Did we check the CRC? This is now only used by tests. reader io.ReadCloser ctx context.Context @@ -218,17 +216,17 @@ func (r *Reader) Read(p []byte) (int, error) { if r.remain != -1 { r.remain -= int64(n) } - if r.checkCRC { - r.gotCRC = crc32.Update(r.gotCRC, crc32cTable, p[:n]) - // Check CRC here. 
It would be natural to check it in Close, but - // everybody defers Close on the assumption that it doesn't return - // anything worth looking at. - if err == io.EOF { - if r.gotCRC != r.wantCRC { - return n, fmt.Errorf("storage: bad CRC on read: got %d, want %d", - r.gotCRC, r.wantCRC) - } - } + return n, err +} + +// WriteTo writes all the data from the Reader to w. Fulfills the io.WriterTo interface. +// This is called implicitly when calling io.Copy on a Reader. +func (r *Reader) WriteTo(w io.Writer) (int64, error) { + // This implicitly calls r.reader.WriteTo for gRPC only. JSON and XML don't have an + // implementation of WriteTo. + n, err := io.Copy(w, r.reader) + if r.remain != -1 { + r.remain -= int64(n) } return n, err } diff --git a/storage/retry_conformance_test.go b/storage/retry_conformance_test.go index 3f9dd618eba5..950b542e2c87 100644 --- a/storage/retry_conformance_test.go +++ b/storage/retry_conformance_test.go @@ -21,7 +21,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/url" "os" @@ -211,12 +210,25 @@ var methods = map[string][]retryFunc{ if err != nil { return err } - wr, err := io.Copy(ioutil.Discard, r) + wr, err := r.WriteTo(io.Discard) if got, want := wr, len(randomBytesToWrite); got != int64(want) { return fmt.Errorf("body length mismatch\ngot:\n%v\n\nwant:\n%v", got, want) } return err }, + func(ctx context.Context, c *Client, fs *resources, _ bool) error { + // This tests downloads by calling Reader.Read rather than Reader.WriteTo. + r, err := c.Bucket(fs.bucket.Name).Object(fs.object.Name).NewReader(ctx) + if err != nil { + return err + } + // Use ReadAll because it calls Read implicitly, not WriteTo. + b, err := io.ReadAll(r) + if got, want := len(b), len(randomBytesToWrite); got != want { + return fmt.Errorf("body length mismatch\ngot:\n%v\n\nwant:\n%v", got, want) + } + return err + }, func(ctx context.Context, c *Client, fs *resources, _ bool) error { // Test JSON reads. 
client, ok := c.tc.(*httpStorageClient) @@ -233,7 +245,7 @@ var methods = map[string][]retryFunc{ if err != nil { return err } - wr, err := io.Copy(ioutil.Discard, r) + wr, err := io.Copy(io.Discard, r) if got, want := wr, len(randomBytesToWrite); got != int64(want) { return fmt.Errorf("body length mismatch\ngot:\n%v\n\nwant:\n%v", got, want) } @@ -253,7 +265,7 @@ var methods = map[string][]retryFunc{ return err } defer r.Close() - data, err := ioutil.ReadAll(r) + data, err := io.ReadAll(r) if err != nil { return fmt.Errorf("failed to ReadAll, err: %v", err) } @@ -265,6 +277,32 @@ var methods = map[string][]retryFunc{ } return nil }, + func(ctx context.Context, c *Client, fs *resources, _ bool) error { + // Test download via Reader.WriteTo. + // Before running the test method, populate a large test object of 3 MiB. + objName := objectIDs.New() + if err := uploadTestObject(fs.bucket.Name, objName, randomBytes3MiB); err != nil { + return fmt.Errorf("failed to create 9 MiB large object pre test, err: %v", err) + } + // Download the large test object for the S8 download method group. + r, err := c.Bucket(fs.bucket.Name).Object(objName).NewReader(ctx) + if err != nil { + return err + } + defer r.Close() + var data bytes.Buffer + _, err = r.WriteTo(&data) + if err != nil { + return fmt.Errorf("failed to ReadAll, err: %v", err) + } + if got, want := data.Len(), 3*MiB; got != want { + return fmt.Errorf("body length mismatch\ngot:\n%v\n\nwant:\n%v", got, want) + } + if got, want := data.Bytes(), randomBytes3MiB; !bytes.Equal(got, want) { + return fmt.Errorf("body mismatch\ngot:\n%v\n\nwant:\n%v", got, want) + } + return nil + }, func(ctx context.Context, c *Client, fs *resources, _ bool) error { // Test JSON reads. // Before running the test method, populate a large test object. 
@@ -289,7 +327,7 @@ var methods = map[string][]retryFunc{ return err } defer r.Close() - data, err := ioutil.ReadAll(r) + data, err := io.ReadAll(r) if err != nil { return fmt.Errorf("failed to ReadAll, err: %v", err) }