Skip to content

Commit

Permalink
Merge pull request #42 from multiformats/feat/simopen
Browse files Browse the repository at this point in the history
Implement simultaneous open extension
  • Loading branch information
aarshkshah1992 committed Feb 12, 2021
2 parents 1587532 + 4c3567f commit 4661b85
Show file tree
Hide file tree
Showing 3 changed files with 396 additions and 3 deletions.
217 changes: 214 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package multistream

import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"io"
"strconv"
"strings"
)

// ErrNotSupported is the error returned when the muxer does not support
Expand All @@ -14,6 +18,12 @@ var ErrNotSupported = errors.New("protocol not supported")
// specified.
var ErrNoProtocols = errors.New("no protocols specified")

const (
tieBreakerPrefix = "select:"
initiator = "initiator"
responder = "responder"
)

// SelectProtoOrFail performs the initial multistream handshake
// to inform the muxer of the protocol that will be used to communicate
// on this ReadWriteCloser. It returns an error if, for example,
Expand All @@ -22,8 +32,10 @@ func SelectProtoOrFail(proto string, rwc io.ReadWriteCloser) error {
errCh := make(chan error, 1)
go func() {
var buf bytes.Buffer
delimWrite(&buf, []byte(ProtocolID))
delimWrite(&buf, []byte(proto))
if err := delitmWriteAll(&buf, []byte(ProtocolID), []byte(proto)); err != nil {
errCh <- err
return
}
_, err := io.Copy(rwc, &buf)
errCh <- err
}()
Expand Down Expand Up @@ -61,7 +73,80 @@ func SelectOneOf(protos []string, rwc io.ReadWriteCloser) (string, error) {
default:
return "", err
}
for _, p := range protos[1:] {
return selectProtosOrFail(protos[1:], rwc)
}

// Performs protocol negotiation with the simultaneous open extension; the returned boolean
// indicator will be true if we should act as a server.
func SelectWithSimopenOrFail(protos []string, rwc io.ReadWriteCloser) (string, bool, error) {
if len(protos) == 0 {
return "", false, ErrNoProtocols
}

werrCh := make(chan error, 1)
go func() {
var buf bytes.Buffer
if err := delitmWriteAll(&buf, []byte(ProtocolID), []byte("iamclient"), []byte(protos[0])); err != nil {
werrCh <- err
return
}

_, err := io.Copy(rwc, &buf)
werrCh <- err
}()

err := readMultistreamHeader(rwc)
if err != nil {
return "", false, err
}

tok, err := ReadNextToken(rwc)
if err != nil {
return "", false, err
}

if err = <-werrCh; err != nil {
return "", false, err
}

switch tok {
case "iamclient":
// simultaneous open
return simOpen(protos, rwc)

case "na":
// client open
proto, err := clientOpen(protos, rwc)
if err != nil {
return "", false, err
}

return proto, false, nil

default:
return "", false, errors.New("unexpected response: " + tok)
}
}

func clientOpen(protos []string, rwc io.ReadWriteCloser) (string, error) {
// check to see if we selected the pipelined protocol
tok, err := ReadNextToken(rwc)
if err != nil {
return "", err
}

switch tok {
case protos[0]:
return tok, nil
case "na":
return selectProtosOrFail(protos[1:], rwc)
default:
return "", errors.New("unexpected response: " + tok)
}
}

func selectProtosOrFail(protos []string, rwc io.ReadWriteCloser) (string, error) {
for _, p := range protos {
err := trySelect(p, rwc)
switch err {
case nil:
Expand All @@ -74,6 +159,132 @@ func SelectOneOf(protos []string, rwc io.ReadWriteCloser) (string, error) {
return "", ErrNotSupported
}

func simOpen(protos []string, rwc io.ReadWriteCloser) (string, bool, error) {
randBytes := make([]byte, 8)
_, err := rand.Read(randBytes)
if err != nil {
return "", false, err
}
myNonce := binary.LittleEndian.Uint64(randBytes)

werrCh := make(chan error, 1)
go func() {
myselect := []byte(tieBreakerPrefix + strconv.FormatUint(myNonce, 10))
err := delimWriteBuffered(rwc, myselect)
werrCh <- err
}()

// skip exactly one protocol
// see https://github.com/multiformats/go-multistream/pull/42#discussion_r558757135
_, err = ReadNextToken(rwc)
if err != nil {
return "", false, err
}

// read the tie breaker nonce
tok, err := ReadNextToken(rwc)
if err != nil {
return "", false, err
}
if !strings.HasPrefix(tok, tieBreakerPrefix) {
return "", false, errors.New("tie breaker nonce not sent with the correct prefix")
}

if err = <-werrCh; err != nil {
return "", false, err
}

peerNonce, err := strconv.ParseUint(tok[len(tieBreakerPrefix):], 10, 64)
if err != nil {
return "", false, err
}

var iamserver bool

if peerNonce == myNonce {
return "", false, errors.New("failed client selection; identical nonces!")
}
iamserver = peerNonce > myNonce

var proto string
if iamserver {
proto, err = simOpenSelectServer(protos, rwc)
} else {
proto, err = simOpenSelectClient(protos, rwc)
}

return proto, iamserver, err
}

func simOpenSelectServer(protos []string, rwc io.ReadWriteCloser) (string, error) {
werrCh := make(chan error, 1)
go func() {
err := delimWriteBuffered(rwc, []byte(responder))
werrCh <- err
}()

tok, err := ReadNextToken(rwc)
if err != nil {
return "", err
}
if tok != initiator {
return "", errors.New("unexpected response: " + tok)
}
if err = <-werrCh; err != nil {
return "", err
}

for {
tok, err = ReadNextToken(rwc)

if err == io.EOF {
return "", ErrNotSupported
}

if err != nil {
return "", err
}

for _, p := range protos {
if tok == p {
err = delimWriteBuffered(rwc, []byte(p))
if err != nil {
return "", err
}

return p, nil
}
}

err = delimWriteBuffered(rwc, []byte("na"))
if err != nil {
return "", err
}
}

}

func simOpenSelectClient(protos []string, rwc io.ReadWriteCloser) (string, error) {
werrCh := make(chan error, 1)
go func() {
err := delimWriteBuffered(rwc, []byte(initiator))
werrCh <- err
}()

tok, err := ReadNextToken(rwc)
if err != nil {
return "", err
}
if tok != responder {
return "", errors.New("unexpected response: " + tok)
}
if err = <-werrCh; err != nil {
return "", err
}

return selectProtosOrFail(protos, rwc)
}

func handshake(rw io.ReadWriter) error {
errCh := make(chan error, 1)
go func() {
Expand Down
11 changes: 11 additions & 0 deletions multistream.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bufio"
"bytes"
"errors"
"fmt"

"io"
"sync"
Expand Down Expand Up @@ -81,6 +82,16 @@ func delimWriteBuffered(w io.Writer, mes []byte) error {
return bw.Flush()
}

func delitmWriteAll(w io.Writer, messages ...[]byte) error {
for _, mes := range messages {
if err := delimWrite(w, mes); err != nil {
return fmt.Errorf("failed to write messages %s, err: %v ", string(mes), err)
}
}

return nil
}

func delimWrite(w io.Writer, mes []byte) error {
err := writeUvarint(w, uint64(len(mes)+1))
if err != nil {
Expand Down
Loading

0 comments on commit 4661b85

Please sign in to comment.