From 0a0b04deb5240f71c05b9d4ab5d0f67f39e34c2d Mon Sep 17 00:00:00 2001 From: Frederic BIDON Date: Thu, 7 Dec 2023 20:54:10 +0100 Subject: [PATCH] ByteStream consumer can write to interface{} * fix(ByteStreamConsumer): may now write into an interface which underlying type is []byte or string. * feat(ByteStreamConsumer): added support to io.ReaderFrom, preferred over io.Writer if available * feat(ByteStreamProducer): added support to io.WriterTo, preferred over io.Reader if available * refact(ByteStreamProducer): removed redundant case "string" and preferred the more general reflected case (supports aliased strings) * test: refactored ByteStream tests * test: added benchmark for bytestream.Consume * fixes #167 Signed-off-by: Frederic BIDON --- bytestream.go | 131 +++++++---- bytestream_test.go | 532 ++++++++++++++++++++++++++++++++------------- 2 files changed, 477 insertions(+), 186 deletions(-) diff --git a/bytestream.go b/bytestream.go index 0a6b8ec..4d9d26c 100644 --- a/bytestream.go +++ b/bytestream.go @@ -38,9 +38,16 @@ type byteStreamOpts struct { Close bool } -// ByteStreamConsumer creates a consumer for byte streams, -// takes a Writer/BinaryUnmarshaler interface or binary slice by reference, -// and reads from the provided reader +// ByteStreamConsumer creates a consumer for byte streams. +// +// The consumer consumes from a provided reader into the data passed by reference. +// +// Supported output underlying types and interfaces, prioritized in this order: +// - io.ReaderFrom (for maximum control) +// - io.Writer (performs io.Copy) +// - encoding.BinaryUnmarshaler +// - *string +// - *[]byte func ByteStreamConsumer(opts ...byteStreamOpt) Consumer { var vals byteStreamOpts for _, opt := range opts { @@ -51,10 +58,13 @@ func ByteStreamConsumer(opts ...byteStreamOpt) Consumer { if reader == nil { return errors.New("ByteStreamConsumer requires a reader") // early exit } + if data == nil { + return errors.New("nil destination for ByteStreamConsumer") + } closer := defaultCloser if vals.Close { - if cl, ok := reader.(io.Closer); ok { + if cl, isReaderCloser := reader.(io.Closer); isReaderCloser { closer = cl.Close } } @@ -62,34 +72,56 @@ func ByteStreamConsumer(opts ...byteStreamOpt) Consumer { _ = closer() }() - if wrtr, ok := data.(io.Writer); ok { - _, err := io.Copy(wrtr, reader) + if readerFrom, isReaderFrom := data.(io.ReaderFrom); isReaderFrom { + _, err := readerFrom.ReadFrom(reader) return err } - buf := new(bytes.Buffer) + if writer, isDataWriter := data.(io.Writer); isDataWriter { + _, err := io.Copy(writer, reader) + return err + } + + // buffers input before writing to data + var buf bytes.Buffer _, err := buf.ReadFrom(reader) if err != nil { return err } b := buf.Bytes() - if bu, ok := data.(encoding.BinaryUnmarshaler); ok { - return bu.UnmarshalBinary(b) - } + switch destinationPointer := data.(type) { + case encoding.BinaryUnmarshaler: + return destinationPointer.UnmarshalBinary(b) + case *any: + switch (*destinationPointer).(type) { + case string: + *destinationPointer = string(b) + + return nil + + case []byte: + *destinationPointer = b - if data != nil { - if str, ok := data.(*string); ok { - *str = string(b) return nil } - } + default: + // check for the underlying type to be pointer to []byte or string, + if ptr := reflect.TypeOf(data); ptr.Kind() != reflect.Ptr { + return errors.New("destination must be a pointer") + } - if t := reflect.TypeOf(data); data != nil && t.Kind() == reflect.Ptr { v := reflect.Indirect(reflect.ValueOf(data)) - if t = v.Type(); t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 { + t := v.Type() + + switch { + case t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8: v.SetBytes(b) return nil + + case t.Kind() == reflect.String: + v.SetString(string(b)) + return nil } } @@ -98,21 +130,35 @@ func ByteStreamConsumer(opts ...byteStreamOpt) Consumer { }) } -// ByteStreamProducer creates a producer for byte streams, -// takes a Reader/BinaryMarshaler interface or binary slice, -// and writes to a writer (essentially a pipe) +// ByteStreamProducer creates a producer for byte streams. +// +// The producer takes input data then writes to an output writer (essentially as a pipe). +// +// Supported input underlying types and interfaces, prioritized in this order: +// - io.WriterTo (for maximum control) +// - io.Reader (performs io.Copy). A ReadCloser is closed before exiting. +// - encoding.BinaryMarshaler +// - error (writes as a string) +// - []byte +// - string +// - struct, other slices: writes as JSON func ByteStreamProducer(opts ...byteStreamOpt) Producer { var vals byteStreamOpts for _, opt := range opts { opt(&vals) } + return ProducerFunc(func(writer io.Writer, data interface{}) error { if writer == nil { return errors.New("ByteStreamProducer requires a writer") // early exit } + if data == nil { + return errors.New("nil destination for ByteStreamProducer") + } + closer := defaultCloser if vals.Close { - if cl, ok := writer.(io.Closer); ok { + if cl, isWriterCloser := writer.(io.Closer); isWriterCloser { closer = cl.Close } } @@ -120,46 +166,51 @@ func ByteStreamProducer(opts ...byteStreamOpt) Producer { _ = closer() }() - if rc, ok := data.(io.ReadCloser); ok { + if rc, isDataCloser := data.(io.ReadCloser); isDataCloser { defer rc.Close() } - if rdr, ok := data.(io.Reader); ok { - _, err := io.Copy(writer, rdr) + switch origin := data.(type) { + case io.WriterTo: + _, err := origin.WriteTo(writer) + return err + + case io.Reader: + _, err := io.Copy(writer, origin) return err - } - if bm, ok := data.(encoding.BinaryMarshaler); ok { - bytes, err := bm.MarshalBinary() + case encoding.BinaryMarshaler: + bytes, err := origin.MarshalBinary() if err != nil { return err } _, err = writer.Write(bytes) return err - } - - if data != nil { - if str, ok := data.(string); ok { - _, err := writer.Write([]byte(str)) - return err - } - if e, ok := data.(error); ok { - _, err := writer.Write([]byte(e.Error())) - return err - } + case error: + _, err := writer.Write([]byte(origin.Error())) + return err + default: v := reflect.Indirect(reflect.ValueOf(data)) - if t := v.Type(); t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 { + t := v.Type() + + switch { + case t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8: _, err := writer.Write(v.Bytes()) return err - } - if t := v.Type(); t.Kind() == reflect.Struct || t.Kind() == reflect.Slice { + + case t.Kind() == reflect.String: + _, err := writer.Write([]byte(v.String())) + return err + + case t.Kind() == reflect.Struct || t.Kind() == reflect.Slice: b, err := swag.WriteJSON(data) if err != nil { return err } + _, err = writer.Write(b) return err } diff --git a/bytestream_test.go b/bytestream_test.go index af0c90c..608a408 100644 --- a/bytestream_test.go +++ b/bytestream_test.go @@ -2,8 +2,10 @@ package runtime import ( "bytes" + "crypto/rand" "errors" "fmt" + "io" "sync/atomic" "testing" @@ -12,120 +14,406 @@ import ( ) func TestByteStreamConsumer(t *testing.T) { - cons := ByteStreamConsumer() - const expected = "the data for the stream to be sent over the wire" + consumer := ByteStreamConsumer() + + t.Run("can consume as a WriterTo", func(t *testing.T) { + var dest io.WriterTo = new(bytes.Buffer) + require.NoError(t, consumer.Consume(bytes.NewBufferString(expected), dest)) + assert.Equal(t, expected, dest.(*bytes.Buffer).String()) + }) + + t.Run("can consume as a Writer", func(t *testing.T) { + dest := &closingWriter{} + require.NoError(t, consumer.Consume(bytes.NewBufferString(expected), dest)) + assert.Equal(t, expected, dest.String()) + }) + + t.Run("can consume as a string", func(t *testing.T) { + var dest string + require.NoError(t, consumer.Consume(bytes.NewBufferString(expected), &dest)) + assert.Equal(t, expected, dest) + }) + + t.Run("can consume as a binary unmarshaler", func(t *testing.T) { + var dest binaryUnmarshalDummy + require.NoError(t, consumer.Consume(bytes.NewBufferString(expected), &dest)) + assert.Equal(t, expected, dest.str) + }) + + t.Run("can consume as a binary slice", func(t *testing.T) { + var dest []byte + require.NoError(t, consumer.Consume(bytes.NewBufferString(expected), &dest)) + assert.Equal(t, expected, string(dest)) + }) + + t.Run("can consume as a type, with underlying as a binary slice", func(t *testing.T) { + type binarySlice []byte + var dest binarySlice + require.NoError(t, consumer.Consume(bytes.NewBufferString(expected), &dest)) + assert.Equal(t, expected, string(dest)) + }) + + t.Run("can consume as a type, with underlying as a string", func(t *testing.T) { + type aliasedString string + var dest aliasedString + require.NoError(t, consumer.Consume(bytes.NewBufferString(expected), &dest)) + assert.Equal(t, expected, string(dest)) + }) + + t.Run("can consume as an interface with underlying type []byte", func(t *testing.T) { + var dest interface{} = []byte{} + require.NoError(t, consumer.Consume(bytes.NewBufferString(expected), &dest)) + asBytes, ok := dest.([]byte) + require.True(t, ok) + assert.Equal(t, expected, string(asBytes)) + }) + + t.Run("can consume as an interface with underlying type string", func(t *testing.T) { + var dest interface{} = "x" + require.NoError(t, consumer.Consume(bytes.NewBufferString(expected), &dest)) + asString, ok := dest.(string) + require.True(t, ok) + assert.Equal(t, expected, asString) + }) + + t.Run("with CloseStream option", func(t *testing.T) { + t.Run("wants to close stream", func(t *testing.T) { + closingConsumer := ByteStreamConsumer(ClosesStream) + var dest bytes.Buffer + r := &closingReader{b: bytes.NewBufferString(expected)} + + require.NoError(t, closingConsumer.Consume(r, &dest)) + assert.Equal(t, expected, dest.String()) + assert.EqualValues(t, 1, r.calledClose) + }) + + t.Run("don't want to close stream", func(t *testing.T) { + nonClosingConsumer := ByteStreamConsumer() + var dest bytes.Buffer + r := &closingReader{b: bytes.NewBufferString(expected)} + + require.NoError(t, nonClosingConsumer.Consume(r, &dest)) + assert.Equal(t, expected, dest.String()) + assert.EqualValues(t, 0, r.calledClose) + }) + }) + + t.Run("error cases", func(t *testing.T) { + t.Run("passing in a nil slice will result in an error", func(t *testing.T) { + var dest *[]byte + require.Error(t, consumer.Consume(bytes.NewBufferString(expected), &dest)) + }) + + t.Run("passing a non-pointer will result in an error", func(t *testing.T) { + var dest []byte + require.Error(t, consumer.Consume(bytes.NewBufferString(expected), dest)) + }) + + t.Run("passing in nil destination result in an error", func(t *testing.T) { + require.Error(t, consumer.Consume(bytes.NewBufferString(expected), nil)) + }) + + t.Run("a reader who results in an error, will make it fail", func(t *testing.T) { + t.Run("binaryUnmarshal case", func(t *testing.T) { + var dest binaryUnmarshalDummy + require.Error(t, consumer.Consume(new(nopReader), &dest)) + }) + + t.Run("[]byte case", func(t *testing.T) { + var dest []byte + require.Error(t, consumer.Consume(new(nopReader), &dest)) + }) + }) + + t.Run("the reader cannot be nil", func(t *testing.T) { + var dest []byte + require.Error(t, consumer.Consume(nil, &dest)) + }) + }) +} - // can consume as a Writer - var b bytes.Buffer - require.NoError(t, cons.Consume(bytes.NewBufferString(expected), &b)) - assert.Equal(t, expected, b.String()) - - // can consume as a string - var s string - require.NoError(t, cons.Consume(bytes.NewBufferString(expected), &s)) - assert.Equal(t, expected, s) - - // can consume as an UnmarshalBinary - var bu binaryUnmarshalDummy - require.NoError(t, cons.Consume(bytes.NewBufferString(expected), &bu)) - assert.Equal(t, expected, bu.str) - - // can consume as a binary slice - var bs []byte - require.NoError(t, cons.Consume(bytes.NewBufferString(expected), &bs)) - assert.Equal(t, expected, string(bs)) - - type binarySlice []byte - var bs2 binarySlice - require.NoError(t, cons.Consume(bytes.NewBufferString(expected), &bs2)) - assert.Equal(t, expected, string(bs2)) - - // passing in a nilslice wil result in an error - var ns *[]byte - require.Error(t, cons.Consume(bytes.NewBufferString(expected), &ns)) - - // passing in nil wil result in an error as well - require.Error(t, cons.Consume(bytes.NewBufferString(expected), nil)) - - // a reader who results in an error, will make it fail - require.Error(t, cons.Consume(new(nopReader), &bu)) - require.Error(t, cons.Consume(new(nopReader), &bs)) +func BenchmarkByteStreamConsumer(b *testing.B) { + const bufferSize = 1000 + expected := make([]byte, bufferSize) + _, err := rand.Read(expected) + require.NoError(b, err) + consumer := ByteStreamConsumer() + input := bytes.NewReader(expected) + + b.Run("with writer", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + var dest bytes.Buffer + for i := 0; i < b.N; i++ { + err = consumer.Consume(input, &dest) + if err != nil { + b.Fatal(err) + } + _, _ = input.Seek(0, io.SeekStart) + dest.Reset() + } + }) + b.Run("with BinaryUnmarshal", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + var dest binaryUnmarshalDummyZeroAlloc + for i := 0; i < b.N; i++ { + err = consumer.Consume(input, &dest) + if err != nil { + b.Fatal(err) + } + _, _ = input.Seek(0, io.SeekStart) + } + }) + b.Run("with string", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + var dest string + for i := 0; i < b.N; i++ { + err = consumer.Consume(input, &dest) + if err != nil { + b.Fatal(err) + } + _, _ = input.Seek(0, io.SeekStart) + } + }) + b.Run("with []byte", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + var dest []byte + for i := 0; i < b.N; i++ { + err = consumer.Consume(input, &dest) + if err != nil { + b.Fatal(err) + } + _, _ = input.Seek(0, io.SeekStart) + } + }) + b.Run("with aliased string", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + type aliasedString string + var dest aliasedString + for i := 0; i < b.N; i++ { + err = consumer.Consume(input, &dest) + if err != nil { + b.Fatal(err) + } + _, _ = input.Seek(0, io.SeekStart) + } + }) + b.Run("with aliased []byte", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + type binarySlice []byte + var dest binarySlice + for i := 0; i < b.N; i++ { + err = consumer.Consume(input, &dest) + if err != nil { + b.Fatal(err) + } + _, _ = input.Seek(0, io.SeekStart) + } + }) +} - // the readers can also not be nil - require.Error(t, cons.Consume(nil, &bs)) +func TestByteStreamProducer(t *testing.T) { + const expected = "the data for the stream to be sent over the wire" + producer := ByteStreamProducer() + + t.Run("can produce from a WriterTo", func(t *testing.T) { + var rdr bytes.Buffer + var data io.WriterTo = bytes.NewBufferString(expected) + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, expected, rdr.String()) + }) + + t.Run("can produce from a Reader", func(t *testing.T) { + var rdr bytes.Buffer + var data io.Reader = bytes.NewBufferString(expected) + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, expected, rdr.String()) + }) + + t.Run("can produce from a binary marshaler", func(t *testing.T) { + var rdr bytes.Buffer + data := &binaryMarshalDummy{str: expected} + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, expected, rdr.String()) + }) + + t.Run("can produce from a string", func(t *testing.T) { + var rdr bytes.Buffer + data := expected + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, expected, rdr.String()) + }) + + t.Run("can produce from a []byte", func(t *testing.T) { + var rdr bytes.Buffer + data := []byte(expected) + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, expected, rdr.String()) + rdr.Reset() + }) + + t.Run("can produce from an error", func(t *testing.T) { + var rdr bytes.Buffer + data := errors.New(expected) + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, expected, rdr.String()) + }) + + t.Run("can produce from an aliased string", func(t *testing.T) { + var rdr bytes.Buffer + type aliasedString string + var data aliasedString = expected + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, expected, rdr.String()) + }) + + t.Run("can produce from an interface with underlying type string", func(t *testing.T) { + var rdr bytes.Buffer + var data interface{} = expected + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, expected, rdr.String()) + }) + + t.Run("can produce from an aliased []byte", func(t *testing.T) { + var rdr bytes.Buffer + type binarySlice []byte + var data binarySlice = []byte(expected) + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, expected, rdr.String()) + }) + + t.Run("can produce from an interface with underling type []byte", func(t *testing.T) { + var rdr bytes.Buffer + var data interface{} = []byte(expected) + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, expected, rdr.String()) + }) + + t.Run("can produce JSON from an arbitrary struct", func(t *testing.T) { + var rdr bytes.Buffer + type dummy struct { + Message string `json:"message,omitempty"` + } + data := dummy{Message: expected} + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, fmt.Sprintf(`{"message":%q}`, expected), rdr.String()) + }) + + t.Run("can produce JSON from a pointer to an arbitrary struct", func(t *testing.T) { + var rdr bytes.Buffer + type dummy struct { + Message string `json:"message,omitempty"` + } + data := dummy{Message: expected} + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, fmt.Sprintf(`{"message":%q}`, expected), rdr.String()) + }) + + t.Run("can produce JSON from an arbitrary slice", func(t *testing.T) { + var rdr bytes.Buffer + data := []string{expected} + require.NoError(t, producer.Produce(&rdr, data)) + assert.Equal(t, fmt.Sprintf(`[%q]`, expected), rdr.String()) + }) + + t.Run("with CloseStream option", func(t *testing.T) { + t.Run("wants to close stream", func(t *testing.T) { + closingProducer := ByteStreamProducer(ClosesStream) + r := &closingWriter{} + data := bytes.NewBufferString(expected) + + require.NoError(t, closingProducer.Produce(r, data)) + assert.Equal(t, expected, r.String()) + assert.EqualValues(t, 1, r.calledClose) + }) + + t.Run("don't want to close stream", func(t *testing.T) { + nonClosingProducer := ByteStreamProducer() + r := &closingWriter{} + data := bytes.NewBufferString(expected) + + require.NoError(t, nonClosingProducer.Produce(r, data)) + assert.Equal(t, expected, r.String()) + assert.EqualValues(t, 0, r.calledClose) + }) + + t.Run("always close data reader whenever possible", func(t *testing.T) { + nonClosingProducer := ByteStreamProducer() + r := &closingWriter{} + data := &closingReader{b: bytes.NewBufferString(expected)} + + require.NoError(t, nonClosingProducer.Produce(r, data)) + assert.Equal(t, expected, r.String()) + assert.EqualValuesf(t, 0, r.calledClose, "expected the input reader NOT to be closed") + assert.EqualValuesf(t, 1, data.calledClose, "expected the data reader to be closed") + }) + }) + + t.Run("error cases", func(t *testing.T) { + t.Run("MarshalBinary error gets propagated", func(t *testing.T) { + var rdr bytes.Buffer + data := new(binaryMarshalDummy) + require.Error(t, producer.Produce(&rdr, data)) + }) + + t.Run("nil data is never accepter", func(t *testing.T) { + var rdr bytes.Buffer + require.Error(t, producer.Produce(&rdr, nil)) + }) + + t.Run("nil readers should also never be acccepted", func(t *testing.T) { + data := expected + require.Error(t, producer.Produce(nil, data)) + }) + + t.Run("bool is an unsupported type", func(t *testing.T) { + var rdr bytes.Buffer + data := true + require.Error(t, producer.Produce(&rdr, data)) + }) + + t.Run("WriteJSON error gets propagated", func(t *testing.T) { + var rdr bytes.Buffer + type cannotMarshal struct { + X func() `json:"x"` + } + data := cannotMarshal{} + require.Error(t, producer.Produce(&rdr, data)) + }) + + }) } type binaryUnmarshalDummy struct { str string } -func (b *binaryUnmarshalDummy) UnmarshalBinary(bytes []byte) error { - if len(bytes) == 0 { +type binaryUnmarshalDummyZeroAlloc struct { + b []byte +} + +func (b *binaryUnmarshalDummy) UnmarshalBinary(data []byte) error { + if len(data) == 0 { return errors.New("no text given") } - b.str = string(bytes) + b.str = string(data) return nil } -func TestByteStreamProducer(t *testing.T) { - cons := ByteStreamProducer() - const expected = "the data for the stream to be sent over the wire" +func (b *binaryUnmarshalDummyZeroAlloc) UnmarshalBinary(data []byte) error { + if len(data) == 0 { + return errors.New("no text given") + } - var rdr bytes.Buffer - - // can produce using a reader - require.NoError(t, cons.Produce(&rdr, bytes.NewBufferString(expected))) - assert.Equal(t, expected, rdr.String()) - rdr.Reset() - - // can produce using a binary marshaller - require.NoError(t, cons.Produce(&rdr, &binaryMarshalDummy{expected})) - assert.Equal(t, expected, rdr.String()) - rdr.Reset() - - // string can also be used to produce - require.NoError(t, cons.Produce(&rdr, expected)) - assert.Equal(t, expected, rdr.String()) - rdr.Reset() - - // binary slices can also be used to produce - require.NoError(t, cons.Produce(&rdr, []byte(expected))) - assert.Equal(t, expected, rdr.String()) - rdr.Reset() - - // errors can also be used to produce - require.NoError(t, cons.Produce(&rdr, errors.New(expected))) - assert.Equal(t, expected, rdr.String()) - rdr.Reset() - - // structs can also be used to produce - require.NoError(t, cons.Produce(&rdr, Error{Message: expected})) - assert.Equal(t, fmt.Sprintf(`{"message":%q}`, expected), rdr.String()) - rdr.Reset() - - // struct pointers can also be used to produce - require.NoError(t, cons.Produce(&rdr, &Error{Message: expected})) - assert.Equal(t, fmt.Sprintf(`{"message":%q}`, expected), rdr.String()) - rdr.Reset() - - // slices can also be used to produce - require.NoError(t, cons.Produce(&rdr, []string{expected})) - assert.Equal(t, fmt.Sprintf(`[%q]`, expected), rdr.String()) - rdr.Reset() - - type binarySlice []byte - require.NoError(t, cons.Produce(&rdr, binarySlice(expected))) - assert.Equal(t, expected, rdr.String()) - rdr.Reset() - - // when binaryMarshal data is used, its potential error gets propagated - require.Error(t, cons.Produce(&rdr, new(binaryMarshalDummy))) - // nil data should never be accepted either - require.Error(t, cons.Produce(&rdr, nil)) - // nil readers should also never be acccepted - require.Error(t, cons.Produce(nil, bytes.NewBufferString(expected))) + b.b = data + return nil } type binaryMarshalDummy struct { @@ -175,51 +463,3 @@ func (c *closingReader) Read(p []byte) (n int, err error) { atomic.AddInt64(&c.calledRead, 1) return c.b.Read(p) } - -func TestBytestreamConsumer_Close(t *testing.T) { - cons := ByteStreamConsumer(ClosesStream) - expected := "the data for the stream to be sent over the wire" - - // can consume as a Writer - var b bytes.Buffer - r := &closingReader{b: bytes.NewBufferString(expected)} - require.NoError(t, cons.Consume(r, &b)) - assert.Equal(t, expected, b.String()) - assert.EqualValues(t, 1, r.calledClose) - - // can consume as a Writer - cons = ByteStreamConsumer() - b.Reset() - r = &closingReader{b: bytes.NewBufferString(expected)} - require.NoError(t, cons.Consume(r, &b)) - assert.Equal(t, expected, b.String()) - assert.EqualValues(t, 0, r.calledClose) -} - -func TestBytestreamProducer_Close(t *testing.T) { - cons := ByteStreamProducer(ClosesStream) - expected := "the data for the stream to be sent over the wire" - - // can consume as a Writer - r := &closingWriter{} - // can produce using a reader - require.NoError(t, cons.Produce(r, bytes.NewBufferString(expected))) - assert.Equal(t, expected, r.String()) - assert.EqualValues(t, 1, r.calledClose) - - cons = ByteStreamProducer() - r = &closingWriter{} - // can produce using a reader - require.NoError(t, cons.Produce(r, bytes.NewBufferString(expected))) - assert.Equal(t, expected, r.String()) - assert.EqualValues(t, 0, r.calledClose) - - cons = ByteStreamProducer() - r = &closingWriter{} - data := &closingReader{b: bytes.NewBufferString(expected)} - // can produce using a readcloser - require.NoError(t, cons.Produce(r, data)) - assert.Equal(t, expected, r.String()) - assert.EqualValues(t, 0, r.calledClose) - assert.EqualValues(t, 1, data.calledClose) -}