From 06f9db1fd450edeb482ee7538075b1f183e872b9 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Wed, 28 Aug 2019 12:25:24 -0700 Subject: [PATCH] ls: make ls more consistent with other protocols fixes #41 Note: To make the `Ls` helper actually _useful_, it now performs the handshake internally. Really, this library isn't built for interactive use while `ls` _is_ so the interfaces are going to be kind of wonky. --- client.go | 6 +++--- multistream.go | 40 ++++++++++++++++++++-------------------- multistream_test.go | 41 +++++++++++++++++++++-------------------- 3 files changed, 44 insertions(+), 43 deletions(-) diff --git a/client.go b/client.go index 9a8f15e..02e096e 100644 --- a/client.go +++ b/client.go @@ -74,13 +74,13 @@ func SelectOneOf(protos []string, rwc io.ReadWriteCloser) (string, error) { return "", ErrNotSupported } -func handshake(rwc io.ReadWriteCloser) error { +func handshake(rw io.ReadWriter) error { errCh := make(chan error, 1) go func() { - errCh <- delimWriteBuffered(rwc, []byte(ProtocolID)) + errCh <- delimWriteBuffered(rw, []byte(ProtocolID)) }() - if err := readMultistreamHeader(rwc); err != nil { + if err := readMultistreamHeader(rw); err != nil { return err } return <-errCh diff --git a/multistream.go b/multistream.go index 85fd23e..1d90a68 100644 --- a/multistream.go +++ b/multistream.go @@ -93,26 +93,34 @@ func delimWrite(w io.Writer, mes []byte) error { // Ls is a Multistream muxer command which returns the list of handler names // available on a muxer. func Ls(rw io.ReadWriter) ([]string, error) { - err := delimWriteBuffered(rw, []byte("ls")) + err := handshake(rw) + if err != nil { + return nil, err + } + err = delimWriteBuffered(rw, []byte("ls")) if err != nil { return nil, err } - n, err := binary.ReadUvarint(&byteReader{rw}) + response, err := lpReadBuf(rw) if err != nil { return nil, err } + r := bytes.NewReader(response) + var out []string - for i := uint64(0); i < n; i++ { - val, err := lpReadBuf(rw) - if err != nil { + for { + val, err := lpReadBuf(r) + switch err { + default: return nil, err + case io.EOF: + return out, nil + case nil: + out = append(out, string(val)) } - out = append(out, string(val)) } - - return out, nil } func fulltextMatch(s string) func(string) bool { @@ -337,11 +345,6 @@ func (msm *MultistreamMuxer) Ls(w io.Writer) error { buf := new(bytes.Buffer) msm.handlerlock.RLock() - err := writeUvarint(buf, uint64(len(msm.handlers))) - if err != nil { - return err - } - for _, h := range msm.handlers { err := delimWrite(buf, []byte(h.AddName)) if err != nil { @@ -351,13 +354,7 @@ func (msm *MultistreamMuxer) Ls(w io.Writer) error { } msm.handlerlock.RUnlock() - ll := make([]byte, 16) - nw := binary.PutUvarint(ll, uint64(buf.Len())) - - r := io.MultiReader(bytes.NewReader(ll[:nw]), buf) - - _, err = io.Copy(w, r) - return err + return delimWrite(w, buf.Bytes()) } // Handle performs protocol negotiation on a ReadWriteCloser @@ -418,6 +415,9 @@ func lpReadBuf(r io.Reader) ([]byte, error) { buf := make([]byte, length) _, err = io.ReadFull(r, buf) if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } return nil, err } diff --git a/multistream_test.go b/multistream_test.go index 5076209..918d43d 100644 --- a/multistream_test.go +++ b/multistream_test.go @@ -3,7 +3,6 @@ package multistream import ( "bytes" "crypto/rand" - "encoding/binary" "fmt" "io" "net" @@ -142,6 +141,7 @@ func TestNegLazyStressRead(t *testing.T) { rwc.Close() } }() + defer func() { close(listener) }() for i := 0; i < count; i++ { a, b := newPipe(t) @@ -568,11 +568,10 @@ func TestTooLargeMessage(t *testing.T) { } func TestLs(t *testing.T) { - // TODO: in go1.7, use subtests (t.Run(....) ) - subtestLs(nil)(t) - subtestLs([]string{"a"})(t) - subtestLs([]string{"a", "b", "c", "d", "e"})(t) - subtestLs([]string{"", "a"})(t) + t.Run("none", subtestLs(nil)) + t.Run("one", subtestLs([]string{"a"})) + t.Run("many", subtestLs([]string{"a", "b", "c", "d", "e"})) + t.Run("empty", subtestLs([]string{"", "a"})) } func subtestLs(protos []string) func(*testing.T) { @@ -584,25 +583,27 @@ func subtestLs(protos []string) func(*testing.T) { mset[p] = true } - buf := new(bytes.Buffer) - err := mr.Ls(buf) - if err != nil { - t.Fatal(err) - } + c1, c2 := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) - n, err := binary.ReadUvarint(buf) - if err != nil { - t.Fatal(err) - } - - if int(n) != buf.Len() { - t.Fatal("length wasnt properly prefixed") - } + proto, _, err := mr.Negotiate(c2) + c2.Close() + if err != io.EOF { + t.Error(err) + } + if proto != "" { + t.Errorf("expected no proto, got %s", proto) + } + }() + defer func() { <-done }() - items, err := Ls(buf) + items, err := Ls(c1) if err != nil { t.Fatal(err) } + c1.Close() if len(items) != len(protos) { t.Fatal("got wrong number of protocols")