diff --git a/impl/initiating_test.go b/impl/initiating_test.go index 2e0af663..36785e7d 100644 --- a/impl/initiating_test.go +++ b/impl/initiating_test.go @@ -332,6 +332,8 @@ func TestDataTransferInitiating(t *testing.T) { }, } for testCase, verify := range testCases { + + // test for new protocol -> new protocol t.Run(testCase, func(t *testing.T) { h := &harness{} ctx, cancel := context.WithTimeout(ctx, 10*time.Second) diff --git a/impl/integration_test.go b/impl/integration_test.go index dc1ec4d9..43abf286 100644 --- a/impl/integration_test.go +++ b/impl/integration_test.go @@ -24,6 +24,7 @@ import ( cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -41,6 +42,20 @@ import ( const loremFile = "lorem.txt" +// nil means use the default protocols +// tests data transfer for the following protocol combinations: +// default protocol -> default protocols +// old protocol -> default protocols +// default protocols -> old protocol +var protocolsForTest = map[string]struct { + host1Protocols []protocol.ID + host2Protocols []protocol.ID +}{ + "(new -> new)": {nil, nil}, + "(old -> new, old)": {[]protocol.ID{datatransfer.ProtocolDataTransfer1_0}, nil}, + "(new, old -> old)": {nil, []protocol.ID{datatransfer.ProtocolDataTransfer1_0}}, +} + func TestRoundTrip(t *testing.T) { ctx := context.Background() testCases := map[string]struct { @@ -77,145 +92,147 @@ func TestRoundTrip(t *testing.T) { }, } for testCase, data := range testCases { - t.Run(testCase, func(t *testing.T) { - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() + for pname, ps := range protocolsForTest { + t.Run(testCase+pname, func(t *testing.T) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t) - host1 := gsData.Host1 // initiator, data sender - host2 := gsData.Host2 // data recipient + gsData := testutil.NewGraphsyncTestingData(ctx, t, ps.host1Protocols, ps.host2Protocols) + host1 := gsData.Host1 // initiator, data sender + host2 := gsData.Host2 // data recipient - tp1 := gsData.SetupGSTransportHost1() - tp2 := gsData.SetupGSTransportHost2() + tp1 := gsData.SetupGSTransportHost1() + tp2 := gsData.SetupGSTransportHost2() - dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1, gsData.StoredCounter1) - require.NoError(t, err) - err = dt1.Start(ctx) - require.NoError(t, err) - dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2, gsData.StoredCounter2) - require.NoError(t, err) - err = dt2.Start(ctx) - require.NoError(t, err) + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1, gsData.StoredCounter1) + require.NoError(t, err) + err = dt1.Start(ctx) + require.NoError(t, err) + dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2, gsData.StoredCounter2) + require.NoError(t, err) + err = dt2.Start(ctx) + require.NoError(t, err) - finished := make(chan struct{}, 2) - errChan := make(chan struct{}, 2) - opened := make(chan struct{}, 2) - sent := make(chan uint64, 21) - received := make(chan uint64, 21) - var subscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) { - if event.Code == datatransfer.DataSent { - if channelState.Sent() > 0 { - sent <- channelState.Sent() + finished := make(chan struct{}, 2) + errChan := make(chan struct{}, 2) + opened := make(chan struct{}, 2) + sent := make(chan uint64, 21) + received := make(chan uint64, 21) + var subscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) { + if event.Code == datatransfer.DataSent { + if channelState.Sent() > 0 { + sent <- channelState.Sent() + } } - } - if event.Code == datatransfer.DataReceived { - if channelState.Received() > 0 { - received <- channelState.Received() + if event.Code == datatransfer.DataReceived { + if channelState.Received() > 0 { + received <- channelState.Received() + } } - } - if channelState.Status() == datatransfer.Completed { - finished <- struct{}{} + if channelState.Status() == datatransfer.Completed { + finished <- struct{}{} + } + if event.Code == datatransfer.Error { + errChan <- struct{}{} + } + if event.Code == datatransfer.Open { + opened <- struct{}{} + } } - if event.Code == datatransfer.Error { - errChan <- struct{}{} + dt1.SubscribeToEvents(subscriber) + dt2.SubscribeToEvents(subscriber) + voucher := testutil.FakeDTType{Data: "applesauce"} + sv := testutil.NewStubbedValidator() + + var sourceDagService ipldformat.DAGService + if data.customSourceStore { + ds := dss.MutexWrap(datastore.NewMapDatastore()) + bs := bstore.NewBlockstore(namespace.Wrap(ds, datastore.NewKey("blockstore"))) + loader := storeutil.LoaderForBlockstore(bs) + storer := storeutil.StorerForBlockstore(bs) + sourceDagService = merkledag.NewDAGService(blockservice.New(bs, offline.Exchange(bs))) + err := dt1.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) { + fv, ok := testVoucher.(*testutil.FakeDTType) + if ok && fv.Data == voucher.Data { + gsTransport, ok := transport.(*tp.Transport) + if ok { + err := gsTransport.UseStore(channelID, loader, storer) + require.NoError(t, err) + } + } + }) + require.NoError(t, err) + } else { + sourceDagService = gsData.DagService1 } - if event.Code == datatransfer.Open { - opened <- struct{}{} + root, origBytes := testutil.LoadUnixFSFile(ctx, t, sourceDagService, loremFile) + rootCid := root.(cidlink.Link).Cid + + var destDagService ipldformat.DAGService + if data.customTargetStore { + ds := dss.MutexWrap(datastore.NewMapDatastore()) + bs := bstore.NewBlockstore(namespace.Wrap(ds, datastore.NewKey("blockstore"))) + loader := storeutil.LoaderForBlockstore(bs) + storer := storeutil.StorerForBlockstore(bs) + destDagService = merkledag.NewDAGService(blockservice.New(bs, offline.Exchange(bs))) + err := dt2.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) { + fv, ok := testVoucher.(*testutil.FakeDTType) + if ok && fv.Data == voucher.Data { + gsTransport, ok := transport.(*tp.Transport) + if ok { + err := gsTransport.UseStore(channelID, loader, storer) + require.NoError(t, err) + } + } + }) + require.NoError(t, err) + } else { + destDagService = gsData.DagService2 } - } - dt1.SubscribeToEvents(subscriber) - dt2.SubscribeToEvents(subscriber) - voucher := testutil.FakeDTType{Data: "applesauce"} - sv := testutil.NewStubbedValidator() - var sourceDagService ipldformat.DAGService - if data.customSourceStore { - ds := dss.MutexWrap(datastore.NewMapDatastore()) - bs := bstore.NewBlockstore(namespace.Wrap(ds, datastore.NewKey("blockstore"))) - loader := storeutil.LoaderForBlockstore(bs) - storer := storeutil.StorerForBlockstore(bs) - sourceDagService = merkledag.NewDAGService(blockservice.New(bs, offline.Exchange(bs))) - err := dt1.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) { - fv, ok := testVoucher.(*testutil.FakeDTType) - if ok && fv.Data == voucher.Data { - gsTransport, ok := transport.(*tp.Transport) - if ok { - err := gsTransport.UseStore(channelID, loader, storer) - require.NoError(t, err) - } - } - }) + var chid datatransfer.ChannelID + if data.isPull { + sv.ExpectSuccessPull() + require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + } else { + sv.ExpectSuccessPush() + require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) + } require.NoError(t, err) - } else { - sourceDagService = gsData.DagService1 - } - root, origBytes := testutil.LoadUnixFSFile(ctx, t, sourceDagService, loremFile) - rootCid := root.(cidlink.Link).Cid - - var destDagService ipldformat.DAGService - if data.customTargetStore { - ds := dss.MutexWrap(datastore.NewMapDatastore()) - bs := bstore.NewBlockstore(namespace.Wrap(ds, datastore.NewKey("blockstore"))) - loader := storeutil.LoaderForBlockstore(bs) - storer := storeutil.StorerForBlockstore(bs) - destDagService = merkledag.NewDAGService(blockservice.New(bs, offline.Exchange(bs))) - err := dt2.RegisterTransportConfigurer(&testutil.FakeDTType{}, func(channelID datatransfer.ChannelID, testVoucher datatransfer.Voucher, transport datatransfer.Transport) { - fv, ok := testVoucher.(*testutil.FakeDTType) - if ok && fv.Data == voucher.Data { - gsTransport, ok := transport.(*tp.Transport) - if ok { - err := gsTransport.UseStore(channelID, loader, storer) - require.NoError(t, err) - } + opens := 0 + completes := 0 + sentIncrements := make([]uint64, 0, 21) + receivedIncrements := make([]uint64, 0, 21) + for opens < 2 || completes < 2 || len(sentIncrements) < 21 || len(receivedIncrements) < 21 { + select { + case <-ctx.Done(): + t.Fatal("Did not complete succcessful data transfer") + case <-finished: + completes++ + case <-opened: + opens++ + case sentIncrement := <-sent: + sentIncrements = append(sentIncrements, sentIncrement) + case receivedIncrement := <-received: + receivedIncrements = append(receivedIncrements, receivedIncrement) + case <-errChan: + t.Fatal("received error on data transfer") } - }) - require.NoError(t, err) - } else { - destDagService = gsData.DagService2 - } - - var chid datatransfer.ChannelID - if data.isPull { - sv.ExpectSuccessPull() - require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) - } else { - sv.ExpectSuccessPush() - require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - chid, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) - } - require.NoError(t, err) - opens := 0 - completes := 0 - sentIncrements := make([]uint64, 0, 21) - receivedIncrements := make([]uint64, 0, 21) - for opens < 2 || completes < 2 || len(sentIncrements) < 21 || len(receivedIncrements) < 21 { - select { - case <-ctx.Done(): - t.Fatal("Did not complete succcessful data transfer") - case <-finished: - completes++ - case <-opened: - opens++ - case sentIncrement := <-sent: - sentIncrements = append(sentIncrements, sentIncrement) - case receivedIncrement := <-received: - receivedIncrements = append(receivedIncrements, receivedIncrement) - case <-errChan: - t.Fatal("received error on data transfer") } - } - require.Equal(t, sentIncrements, receivedIncrements) - testutil.VerifyHasFile(ctx, t, destDagService, root, origBytes) - if data.isPull { - assert.Equal(t, chid.Initiator, host2.ID()) - } else { - assert.Equal(t, chid.Initiator, host1.ID()) - } - }) - } + require.Equal(t, sentIncrements, receivedIncrements) + testutil.VerifyHasFile(ctx, t, destDagService, root, origBytes) + if data.isPull { + assert.Equal(t, chid.Initiator, host2.ID()) + } else { + assert.Equal(t, chid.Initiator, host1.ID()) + } + }) + } + } // } func TestMultipleRoundTripMultipleStores(t *testing.T) { @@ -237,7 +254,7 @@ func TestMultipleRoundTripMultipleStores(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t) + gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender host2 := gsData.Host2 // data recipient @@ -364,7 +381,7 @@ func TestManyReceiversAtOnce(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t) + gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender tp1 := gsData.SetupGSTransportHost1() @@ -499,7 +516,7 @@ func TestRoundTripCancelledRequest(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t) + gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender host2 := gsData.Host2 @@ -642,7 +659,7 @@ func TestSimulatedRetrievalFlow(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 4*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t) + gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender root := gsData.LoadUnixFSFile(t, false) @@ -771,7 +788,7 @@ func TestPauseAndResume(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t) + gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender host2 := gsData.Host2 // data recipient @@ -914,7 +931,7 @@ func TestUnrecognizedVoucherRoundTrip(t *testing.T) { // ctx, cancel := context.WithTimeout(ctx, 5*time.Second) // defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t) + gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender host2 := gsData.Host2 // data recipient @@ -985,7 +1002,7 @@ func TestDataTransferSubscribing(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t) + gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host2 := gsData.Host2 tp1 := gsData.SetupGSTransportHost1() @@ -1115,7 +1132,7 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t) + gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator and data sender host2 := gsData.Host2 // data recipient, makes graphsync request for data voucher := testutil.NewFakeDTType() @@ -1199,7 +1216,7 @@ func TestResponseHookWhenExtensionNotFound(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t) + gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator and data sender host2 := gsData.Host2 // data recipient, makes graphsync request for data voucher := testutil.FakeDTType{Data: "applesauce"} @@ -1326,7 +1343,7 @@ func TestRespondingToPullGraphsyncRequests(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - gsData := testutil.NewGraphsyncTestingData(ctx, t) + gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) // setup receiving peer to just record message coming in gsr := &fakeGraphSyncReceiver{ diff --git a/impl/restart.go b/impl/restart.go index 4290775c..21b1fa85 100644 --- a/impl/restart.go +++ b/impl/restart.go @@ -3,7 +3,6 @@ package impl import ( "bytes" "context" - "fmt" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/peer" @@ -39,7 +38,11 @@ func (m *manager) restartManagerPeerReceivePush(ctx context.Context, channel dat // send a libp2p message to the other peer asking to send a "restart push request" req := message.RestartExistingChannelRequest(channel.ChannelID()) - return m.dataTransferNetwork.SendMessage(ctx, channel.OtherPeer(), req) + if err := m.dataTransferNetwork.SendMessage(ctx, channel.OtherPeer(), req); err != nil { + return xerrors.Errorf("unable to send restart request: %w", err) + } + + return nil } func (m *manager) restartManagerPeerReceivePull(ctx context.Context, channel datatransfer.ChannelState) error { @@ -50,7 +53,11 @@ func (m *manager) restartManagerPeerReceivePull(ctx context.Context, channel dat req := message.RestartExistingChannelRequest(channel.ChannelID()) // send a libp2p message to the other peer asking to send a "restart pull request" - return m.dataTransferNetwork.SendMessage(ctx, channel.OtherPeer(), req) + if err := m.dataTransferNetwork.SendMessage(ctx, channel.OtherPeer(), req); err != nil { + return xerrors.Errorf("unable to send restart request: %w", err) + } + + return nil } func (m *manager) validateRestartVoucher(channel datatransfer.ChannelState, isPull bool) error { @@ -91,9 +98,7 @@ func (m *manager) openPushRestartChannel(ctx context.Context, channel datatransf } m.dataTransferNetwork.Protect(requestTo, chid.String()) if err := m.dataTransferNetwork.SendMessage(ctx, requestTo, req); err != nil { - err = fmt.Errorf("Unable to send request: %w", err) - _ = m.channels.Error(chid, err) - return err + return xerrors.Errorf("Unable to send restart request: %w", err) } return nil @@ -118,9 +123,7 @@ func (m *manager) openPullRestartChannel(ctx context.Context, channel datatransf } m.dataTransferNetwork.Protect(requestTo, chid.String()) if err := m.transport.OpenChannel(ctx, requestTo, chid, cidlink.Link{Cid: baseCid}, selector, channel.ReceivedCids(), req); err != nil { - err = fmt.Errorf("Unable to send request: %w", err) - _ = m.channels.Error(chid, err) - return err + return xerrors.Errorf("Unable to send open channel restart request: %w", err) } return nil diff --git a/impl/restart_integration_test.go b/impl/restart_integration_test.go index 30bafed2..a52d518d 100644 --- a/impl/restart_integration_test.go +++ b/impl/restart_integration_test.go @@ -436,7 +436,7 @@ func newRestartHarness(t *testing.T) *restartHarness { ctx, cancel := context.WithTimeout(ctx, 60*time.Second) // Setup host - gsData := testutil.NewGraphsyncTestingData(ctx, t) + gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil) host1 := gsData.Host1 // initiator, data sender host2 := gsData.Host2 // data recipient peer1 := host1.ID() diff --git a/message.go b/message.go index 480cc964..f0339683 100644 --- a/message.go +++ b/message.go @@ -5,11 +5,21 @@ import ( "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" + "github.com/libp2p/go-libp2p-core/protocol" cborgen "github.com/whyrusleeping/cbor-gen" "github.com/filecoin-project/go-data-transfer/encoding" ) +var ( + // ProtocolDataTransfer1_1 is the protocol identifier for graphsync messages + ProtocolDataTransfer1_1 protocol.ID = "/fil/datatransfer/1.1.0" + + // ProtocolDataTransfer1_0 is the protocol identifier for legacy graphsync messages + // This protocol does NOT support the `Restart` functionality for data transfer channels. + ProtocolDataTransfer1_0 protocol.ID = "/fil/datatransfer/1.0.0" +) + // Message is a message for the data transfer protocol // (either request or response) that can serialize to a protobuf type Message interface { @@ -23,6 +33,7 @@ type Message interface { cborgen.CBORMarshaler cborgen.CBORUnmarshaler ToNet(w io.Writer) error + MessageForProtocol(targetProtocol protocol.ID) (newMsg Message, err error) } // Request is a response message for the data transfer protocol diff --git a/message/message.go b/message/message.go index 1a2ecf6d..4fa34793 100644 --- a/message/message.go +++ b/message/message.go @@ -1,195 +1,18 @@ package message import ( - "io" - - "github.com/ipfs/go-cid" - "github.com/ipld/go-ipld-prime" - cborgen "github.com/whyrusleeping/cbor-gen" - xerrors "golang.org/x/xerrors" - - datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/encoding" -) - -type messageType uint64 - -const ( - newMessage messageType = iota - restartMessage - updateMessage - cancelMessage - completeMessage - voucherMessage - voucherResultMessage - restartExistingChannelRequestMessage + "github.com/filecoin-project/go-data-transfer/message/message1_1" ) -// NewRequest generates a new request for the data transfer protocol -func NewRequest(id datatransfer.TransferID, isRestart bool, isPull bool, vtype datatransfer.TypeIdentifier, voucher encoding.Encodable, baseCid cid.Cid, selector ipld.Node) (datatransfer.Request, error) { - vbytes, err := encoding.Encode(voucher) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) - } - if baseCid == cid.Undef { - return nil, xerrors.Errorf("base CID must be defined") - } - selBytes, err := encoding.Encode(selector) - if err != nil { - return nil, xerrors.Errorf("Error encoding selector") - } - - var typ uint64 - if isRestart { - typ = uint64(restartMessage) - } else { - typ = uint64(newMessage) - } - - return &transferRequest{ - Type: typ, - Pull: isPull, - Vouch: &cborgen.Deferred{Raw: vbytes}, - Stor: &cborgen.Deferred{Raw: selBytes}, - BCid: &baseCid, - VTyp: vtype, - XferID: uint64(id), - }, nil -} - -// RestartExistingChannelRequest creates a request to ask the other side to restart an existing channel -func RestartExistingChannelRequest(channelId datatransfer.ChannelID) datatransfer.Request { - - return &transferRequest{Type: uint64(restartExistingChannelRequestMessage), - RestartChannel: channelId} -} - -// CancelRequest request generates a request to cancel an in progress request -func CancelRequest(id datatransfer.TransferID) datatransfer.Request { - return &transferRequest{ - Type: uint64(cancelMessage), - XferID: uint64(id), - } -} - -// UpdateRequest generates a new request update -func UpdateRequest(id datatransfer.TransferID, isPaused bool) datatransfer.Request { - return &transferRequest{ - Type: uint64(updateMessage), - Paus: isPaused, - XferID: uint64(id), - } -} - -// VoucherRequest generates a new request for the data transfer protocol -func VoucherRequest(id datatransfer.TransferID, vtype datatransfer.TypeIdentifier, voucher encoding.Encodable) (datatransfer.Request, error) { - vbytes, err := encoding.Encode(voucher) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) - } - return &transferRequest{ - Type: uint64(voucherMessage), - Vouch: &cborgen.Deferred{Raw: vbytes}, - VTyp: vtype, - XferID: uint64(id), - }, nil -} - -// RestartResponse builds a new Data Transfer response -func RestartResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vbytes, err := encoding.Encode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) - } - return &transferResponse{ - Acpt: accepted, - Type: uint64(restartMessage), - Paus: isPaused, - XferID: uint64(id), - VTyp: voucherResultType, - VRes: &cborgen.Deferred{Raw: vbytes}, - }, nil -} - -// NewResponse builds a new Data Transfer response -func NewResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vbytes, err := encoding.Encode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) - } - return &transferResponse{ - Acpt: accepted, - Type: uint64(newMessage), - Paus: isPaused, - XferID: uint64(id), - VTyp: voucherResultType, - VRes: &cborgen.Deferred{Raw: vbytes}, - }, nil -} - -// VoucherResultResponse builds a new response for a voucher result -func VoucherResultResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vbytes, err := encoding.Encode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) - } - return &transferResponse{ - Acpt: accepted, - Type: uint64(voucherResultMessage), - Paus: isPaused, - XferID: uint64(id), - VTyp: voucherResultType, - VRes: &cborgen.Deferred{Raw: vbytes}, - }, nil -} - -// UpdateResponse returns a new update response -func UpdateResponse(id datatransfer.TransferID, isPaused bool) datatransfer.Response { - return &transferResponse{ - Type: uint64(updateMessage), - Paus: isPaused, - XferID: uint64(id), - } -} - -// CancelResponse makes a new cancel response message -func CancelResponse(id datatransfer.TransferID) datatransfer.Response { - return &transferResponse{ - Type: uint64(cancelMessage), - XferID: uint64(id), - } -} - -// CompleteResponse returns a new complete response message -func CompleteResponse(id datatransfer.TransferID, isAccepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { - vbytes, err := encoding.Encode(voucherResult) - if err != nil { - return nil, xerrors.Errorf("Creating request: %w", err) - } - return &transferResponse{ - Type: uint64(completeMessage), - Acpt: isAccepted, - Paus: isPaused, - VTyp: voucherResultType, - VRes: &cborgen.Deferred{Raw: vbytes}, - XferID: uint64(id), - }, nil -} - -// FromNet can read a network stream to deserialize a GraphSyncMessage -func FromNet(r io.Reader) (datatransfer.Message, error) { - tresp := transferMessage{} - err := tresp.UnmarshalCBOR(r) - if err != nil { - return nil, err - } - - if (tresp.IsRequest() && tresp.Request == nil) || (!tresp.IsRequest() && tresp.Response == nil) { - return nil, xerrors.Errorf("invalid/malformed message") - } - - if tresp.IsRequest() { - return tresp.Request, nil - } - return tresp.Response, nil -} +var NewRequest = message1_1.NewRequest +var RestartExistingChannelRequest = message1_1.RestartExistingChannelRequest +var UpdateRequest = message1_1.UpdateRequest +var VoucherRequest = message1_1.VoucherRequest +var RestartResponse = message1_1.RestartResponse +var NewResponse = message1_1.NewResponse +var VoucherResultResponse = message1_1.VoucherResultResponse +var CancelResponse = message1_1.CancelResponse +var UpdateResponse = message1_1.UpdateResponse +var FromNet = message1_1.FromNet +var CompleteResponse = message1_1.CompleteResponse +var CancelRequest = message1_1.CancelRequest diff --git a/message/message1_0/message.go b/message/message1_0/message.go new file mode 100644 index 00000000..23856886 --- /dev/null +++ b/message/message1_0/message.go @@ -0,0 +1,57 @@ +package message1_0 + +import ( + "io" + + "github.com/ipfs/go-cid" + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" + + datatransfer "github.com/filecoin-project/go-data-transfer" +) + +// NewTransferRequest creates a transfer request for the 1_0 Data Transfer Protocol. +func NewTransferRequest(bcid *cid.Cid, typ uint64, paus, part, pull bool, stor, vouch *cbg.Deferred, + vtyp datatransfer.TypeIdentifier, xferId uint64) datatransfer.Request { + return &transferRequest{ + BCid: bcid, + Type: typ, + Paus: paus, + Part: part, + Pull: pull, + Stor: stor, + Vouch: vouch, + VTyp: vtyp, + XferID: xferId, + } +} + +// NewTransferRequest creates a transfer response for the 1_0 Data Transfer Protocol. +func NewTransferResponse(typ uint64, acpt bool, paus bool, xferId uint64, vRes *cbg.Deferred, vtyp datatransfer.TypeIdentifier) datatransfer.Response { + return &transferResponse{ + Type: typ, + Acpt: acpt, + Paus: paus, + XferID: xferId, + VRes: vRes, + VTyp: vtyp, + } +} + +// FromNet can read a network stream to deserialize a GraphSyncMessage +func FromNet(r io.Reader) (datatransfer.Message, error) { + tresp := transferMessage{} + err := tresp.UnmarshalCBOR(r) + if err != nil { + return nil, err + } + + if (tresp.IsRequest() && tresp.Request == nil) || (!tresp.IsRequest() && tresp.Response == nil) { + return nil, xerrors.Errorf("invalid/malformed message") + } + + if tresp.IsRequest() { + return tresp.Request, nil + } + return tresp.Response, nil +} diff --git a/message/transfer_message.go b/message/message1_0/transfer_message.go similarity index 97% rename from message/transfer_message.go rename to message/message1_0/transfer_message.go index 303b29a3..4da04b1d 100644 --- a/message/transfer_message.go +++ b/message/message1_0/transfer_message.go @@ -1,4 +1,4 @@ -package message +package message1_0 import ( "io" diff --git a/message/transfer_message_cbor_gen.go b/message/message1_0/transfer_message_cbor_gen.go similarity index 90% rename from message/transfer_message_cbor_gen.go rename to message/message1_0/transfer_message_cbor_gen.go index f6211da0..ab1f2695 100644 --- a/message/transfer_message_cbor_gen.go +++ b/message/message1_0/transfer_message_cbor_gen.go @@ -1,6 +1,6 @@ // Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. -package message +package message1_0 import ( "fmt" @@ -28,12 +28,12 @@ func (t *transferMessage) MarshalCBOR(w io.Writer) error { return err } - // t.Request (message.transferRequest) (struct) + // t.Request (message1_0.transferRequest) (struct) if err := t.Request.MarshalCBOR(w); err != nil { return err } - // t.Response (message.transferResponse) (struct) + // t.Response (message1_0.transferResponse) (struct) if err := t.Response.MarshalCBOR(w); err != nil { return err } @@ -75,7 +75,7 @@ func (t *transferMessage) UnmarshalCBOR(r io.Reader) error { default: return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) } - // t.Request (message.transferRequest) (struct) + // t.Request (message1_0.transferRequest) (struct) { @@ -94,7 +94,7 @@ func (t *transferMessage) UnmarshalCBOR(r io.Reader) error { } } - // t.Response (message.transferResponse) (struct) + // t.Response (message1_0.transferResponse) (struct) { diff --git a/message/transfer_request.go b/message/message1_0/transfer_request.go similarity index 81% rename from message/transfer_request.go rename to message/message1_0/transfer_request.go index 3d39da18..9c426861 100644 --- a/message/transfer_request.go +++ b/message/message1_0/transfer_request.go @@ -1,4 +1,4 @@ -package message +package message1_0 import ( "bytes" @@ -8,11 +8,13 @@ import ( "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/codec/dagcbor" basicnode "github.com/ipld/go-ipld-prime/node/basic" + "github.com/libp2p/go-libp2p-core/protocol" cbg "github.com/whyrusleeping/cbor-gen" xerrors "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer" "github.com/filecoin-project/go-data-transfer/encoding" + "github.com/filecoin-project/go-data-transfer/message/types" ) //go:generate cbor-gen-for transferRequest @@ -29,8 +31,6 @@ type transferRequest struct { Vouch *cbg.Deferred VTyp datatransfer.TypeIdentifier XferID uint64 - - RestartChannel datatransfer.ChannelID } // IsRequest always returns true in this case because this is a transfer request @@ -38,31 +38,25 @@ func (trq *transferRequest) IsRequest() bool { return true } -func (trq *transferRequest) IsRestart() bool { - return trq.Type == uint64(restartMessage) -} - -func (trq *transferRequest) IsRestartExistingChannelRequest() bool { - return trq.Type == uint64(restartExistingChannelRequestMessage) -} - -func (trq *transferRequest) RestartChannelId() (datatransfer.ChannelID, error) { - if !trq.IsRestartExistingChannelRequest() { - return datatransfer.ChannelID{}, xerrors.New("not a restart request") +func (trq *transferRequest) MessageForProtocol(targetProtocol protocol.ID) (datatransfer.Message, error) { + switch targetProtocol { + case datatransfer.ProtocolDataTransfer1_0: + return trq, nil + default: + return nil, xerrors.Errorf("protocol not supported") } - return trq.RestartChannel, nil } func (trq *transferRequest) IsNew() bool { - return trq.Type == uint64(newMessage) + return trq.Type == uint64(types.NewMessage) } func (trq *transferRequest) IsUpdate() bool { - return trq.Type == uint64(updateMessage) + return trq.Type == uint64(types.UpdateMessage) } func (trq *transferRequest) IsVoucher() bool { - return trq.Type == uint64(voucherMessage) || trq.Type == uint64(newMessage) + return trq.Type == uint64(types.VoucherMessage) || trq.Type == uint64(types.NewMessage) } func (trq *transferRequest) IsPaused() bool { @@ -120,7 +114,7 @@ func (trq *transferRequest) Selector() (ipld.Node, error) { // IsCancel returns true if this is a cancel request func (trq *transferRequest) IsCancel() bool { - return trq.Type == uint64(cancelMessage) + return trq.Type == uint64(types.CancelMessage) } // IsPartial returns true if this is a partial request @@ -138,3 +132,15 @@ func (trq *transferRequest) ToNet(w io.Writer) error { } return msg.MarshalCBOR(w) } + +func (trq *transferRequest) IsRestart() bool { + return false +} + +func (trq *transferRequest) IsRestartExistingChannelRequest() bool { + return false +} + +func (trq *transferRequest) RestartChannelId() (datatransfer.ChannelID, error) { + return datatransfer.ChannelID{}, xerrors.New("not supported") +} diff --git a/message/message1_0/transfer_request_cbor_gen.go b/message/message1_0/transfer_request_cbor_gen.go new file mode 100644 index 00000000..61c15dee --- /dev/null +++ b/message/message1_0/transfer_request_cbor_gen.go @@ -0,0 +1,243 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + +package message1_0 + +import ( + "fmt" + "io" + + datatransfer "github.com/filecoin-project/go-data-transfer" + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" +) + +var _ = xerrors.Errorf + +var lengthBuftransferRequest = []byte{137} + +func (t *transferRequest) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write(lengthBuftransferRequest); err != nil { + return err + } + + scratch := make([]byte, 9) + + // t.BCid (cid.Cid) (struct) + + if t.BCid == nil { + if _, err := w.Write(cbg.CborNull); err != nil { + return err + } + } else { + if err := cbg.WriteCidBuf(scratch, w, *t.BCid); err != nil { + return xerrors.Errorf("failed to write cid field t.BCid: %w", err) + } + } + + // t.Type (uint64) (uint64) + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Type)); err != nil { + return err + } + + // t.Paus (bool) (bool) + if err := cbg.WriteBool(w, t.Paus); err != nil { + return err + } + + // t.Part (bool) (bool) + if err := cbg.WriteBool(w, t.Part); err != nil { + return err + } + + // t.Pull (bool) (bool) + if err := cbg.WriteBool(w, t.Pull); err != nil { + return err + } + + // t.Stor (typegen.Deferred) (struct) + if err := t.Stor.MarshalCBOR(w); err != nil { + return err + } + + // t.Vouch (typegen.Deferred) (struct) + if err := t.Vouch.MarshalCBOR(w); err != nil { + return err + } + + // t.VTyp (datatransfer.TypeIdentifier) (string) + if len(t.VTyp) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.VTyp was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.VTyp))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.VTyp)); err != nil { + return err + } + + // t.XferID (uint64) (uint64) + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.XferID)); err != nil { + return err + } + + return nil +} + +func (t *transferRequest) UnmarshalCBOR(r io.Reader) error { + *t = transferRequest{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra != 9 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + // t.BCid (cid.Cid) (struct) + + { + + b, err := br.ReadByte() + if err != nil { + return err + } + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { + return err + } + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.BCid: %w", err) + } + + t.BCid = &c + } + + } + // t.Type (uint64) (uint64) + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Type = uint64(extra) + + } + // t.Paus (bool) (bool) + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.Paus = false + case 21: + t.Paus = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.Part (bool) (bool) + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.Part = false + case 21: + t.Part = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.Pull (bool) (bool) + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.Pull = false + case 21: + t.Pull = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.Stor (typegen.Deferred) (struct) + + { + + t.Stor = new(cbg.Deferred) + + if err := t.Stor.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("failed to read deferred field: %w", err) + } + } + // t.Vouch (typegen.Deferred) (struct) + + { + + t.Vouch = new(cbg.Deferred) + + if err := t.Vouch.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("failed to read deferred field: %w", err) + } + } + // t.VTyp (datatransfer.TypeIdentifier) (string) + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.VTyp = datatransfer.TypeIdentifier(sval) + } + // t.XferID (uint64) (uint64) + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.XferID = uint64(extra) + + } + return nil +} diff --git a/message/transfer_response.go b/message/message1_0/transfer_response.go similarity index 73% rename from message/transfer_response.go rename to message/message1_0/transfer_response.go index c5d9daf2..7de7a381 100644 --- a/message/transfer_response.go +++ b/message/message1_0/transfer_response.go @@ -1,13 +1,15 @@ -package message +package message1_0 import ( "io" + "github.com/libp2p/go-libp2p-core/protocol" cbg "github.com/whyrusleeping/cbor-gen" xerrors "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer" "github.com/filecoin-project/go-data-transfer/encoding" + "github.com/filecoin-project/go-data-transfer/message/types" ) //go:generate cbor-gen-for transferResponse @@ -26,6 +28,19 @@ func (trsp *transferResponse) TransferID() datatransfer.TransferID { return datatransfer.TransferID(trsp.XferID) } +func (trsp *transferResponse) IsRestart() bool { + return false +} + +func (trsp *transferResponse) MessageForProtocol(targetProtocol protocol.ID) (datatransfer.Message, error) { + switch targetProtocol { + case datatransfer.ProtocolDataTransfer1_0: + return trsp, nil + default: + return nil, xerrors.Errorf("protocol not supported") + } +} + // IsRequest always returns false in this case because this is a transfer response func (trsp *transferResponse) IsRequest() bool { return false @@ -33,12 +48,12 @@ func (trsp *transferResponse) IsRequest() bool { // IsNew returns true if this is the first response sent func (trsp *transferResponse) IsNew() bool { - return trsp.Type == uint64(newMessage) + return trsp.Type == uint64(types.NewMessage) } // IsUpdate returns true if this response is an update func (trsp *transferResponse) IsUpdate() bool { - return trsp.Type == uint64(updateMessage) + return trsp.Type == uint64(types.UpdateMessage) } // IsPaused returns true if the responder is paused @@ -48,17 +63,17 @@ func (trsp *transferResponse) IsPaused() bool { // IsCancel returns true if the responder has cancelled this response func (trsp *transferResponse) IsCancel() bool { - return trsp.Type == uint64(cancelMessage) + return trsp.Type == uint64(types.CancelMessage) } // IsComplete returns true if the responder has completed this response func (trsp *transferResponse) IsComplete() bool { - return trsp.Type == uint64(completeMessage) + return trsp.Type == uint64(types.CompleteMessage) } func (trsp *transferResponse) IsVoucherResult() bool { - return trsp.Type == uint64(voucherResultMessage) || trsp.Type == uint64(newMessage) || trsp.Type == uint64(completeMessage) || - trsp.Type == uint64(restartMessage) + return trsp.Type == uint64(types.VoucherResultMessage) || trsp.Type == uint64(types.NewMessage) || + trsp.Type == uint64(types.CompleteMessage) } // Accepted returns true if the request is accepted in the response @@ -77,10 +92,6 @@ func (trsp *transferResponse) VoucherResult(decoder encoding.Decoder) (encoding. return decoder.DecodeFromCbor(trsp.VRes.Raw) } -func (trq *transferResponse) IsRestart() bool { - return trq.Type == uint64(restartMessage) -} - func (trsp *transferResponse) EmptyVoucherResult() bool { return trsp.VTyp == datatransfer.EmptyTypeIdentifier } diff --git a/message/transfer_response_cbor_gen.go b/message/message1_0/transfer_response_cbor_gen.go similarity index 99% rename from message/transfer_response_cbor_gen.go rename to message/message1_0/transfer_response_cbor_gen.go index ab86a0c9..f07de23c 100644 --- a/message/transfer_response_cbor_gen.go +++ b/message/message1_0/transfer_response_cbor_gen.go @@ -1,6 +1,6 @@ // Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. -package message +package message1_0 import ( "fmt" diff --git a/message/message1_1/message.go b/message/message1_1/message.go new file mode 100644 index 00000000..cf60dc4a --- /dev/null +++ b/message/message1_1/message.go @@ -0,0 +1,183 @@ +package message1_1 + +import ( + "io" + + "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime" + cborgen "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" + + datatransfer "github.com/filecoin-project/go-data-transfer" + "github.com/filecoin-project/go-data-transfer/encoding" + "github.com/filecoin-project/go-data-transfer/message/types" +) + +// NewRequest generates a new request for the data transfer protocol +func NewRequest(id datatransfer.TransferID, isRestart bool, isPull bool, vtype datatransfer.TypeIdentifier, voucher encoding.Encodable, baseCid cid.Cid, selector ipld.Node) (datatransfer.Request, error) { + vbytes, err := encoding.Encode(voucher) + if err != nil { + return nil, xerrors.Errorf("Creating request: %w", err) + } + if baseCid == cid.Undef { + return nil, xerrors.Errorf("base CID must be defined") + } + selBytes, err := encoding.Encode(selector) + if err != nil { + return nil, xerrors.Errorf("Error encoding selector") + } + + var typ uint64 + if isRestart { + typ = uint64(types.RestartMessage) + } else { + typ = uint64(types.NewMessage) + } + + return &transferRequest1_1{ + Type: typ, + Pull: isPull, + Vouch: &cborgen.Deferred{Raw: vbytes}, + Stor: &cborgen.Deferred{Raw: selBytes}, + BCid: &baseCid, + VTyp: vtype, + XferID: uint64(id), + }, nil +} + +// RestartExistingChannelRequest creates a request to ask the other side to restart an existing channel +func RestartExistingChannelRequest(channelId datatransfer.ChannelID) datatransfer.Request { + + return &transferRequest1_1{Type: uint64(types.RestartExistingChannelRequestMessage), + RestartChannel: channelId} +} + +// CancelRequest request generates a request to cancel an in progress request +func CancelRequest(id datatransfer.TransferID) datatransfer.Request { + return &transferRequest1_1{ + Type: uint64(types.CancelMessage), + XferID: uint64(id), + } +} + +// UpdateRequest generates a new request update +func UpdateRequest(id datatransfer.TransferID, isPaused bool) datatransfer.Request { + return &transferRequest1_1{ + Type: uint64(types.UpdateMessage), + Paus: isPaused, + XferID: uint64(id), + } +} + +// VoucherRequest generates a new request for the data transfer protocol +func VoucherRequest(id datatransfer.TransferID, vtype datatransfer.TypeIdentifier, voucher encoding.Encodable) (datatransfer.Request, error) { + vbytes, err := encoding.Encode(voucher) + if err != nil { + return nil, xerrors.Errorf("Creating request: %w", err) + } + return &transferRequest1_1{ + Type: uint64(types.VoucherMessage), + Vouch: &cborgen.Deferred{Raw: vbytes}, + VTyp: vtype, + XferID: uint64(id), + }, nil +} + +// RestartResponse builds a new Data Transfer response +func RestartResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { + vbytes, err := encoding.Encode(voucherResult) + if err != nil { + return nil, xerrors.Errorf("Creating request: %w", err) + } + return &transferResponse1_1{ + Acpt: accepted, + Type: uint64(types.RestartMessage), + Paus: isPaused, + XferID: uint64(id), + VTyp: voucherResultType, + VRes: &cborgen.Deferred{Raw: vbytes}, + }, nil +} + +// NewResponse builds a new Data Transfer response +func NewResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { + vbytes, err := encoding.Encode(voucherResult) + if err != nil { + return nil, xerrors.Errorf("Creating request: %w", err) + } + return &transferResponse1_1{ + Acpt: accepted, + Type: uint64(types.NewMessage), + Paus: isPaused, + XferID: uint64(id), + VTyp: voucherResultType, + VRes: &cborgen.Deferred{Raw: vbytes}, + }, nil +} + +// VoucherResultResponse builds a new response for a voucher result +func VoucherResultResponse(id datatransfer.TransferID, accepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { + vbytes, err := encoding.Encode(voucherResult) + if err != nil { + return nil, xerrors.Errorf("Creating request: %w", err) + } + return &transferResponse1_1{ + Acpt: accepted, + Type: uint64(types.VoucherResultMessage), + Paus: isPaused, + XferID: uint64(id), + VTyp: voucherResultType, + VRes: &cborgen.Deferred{Raw: vbytes}, + }, nil +} + +// UpdateResponse returns a new update response +func UpdateResponse(id datatransfer.TransferID, isPaused bool) datatransfer.Response { + return &transferResponse1_1{ + Type: uint64(types.UpdateMessage), + Paus: isPaused, + XferID: uint64(id), + } +} + +// CancelResponse makes a new cancel response message +func CancelResponse(id datatransfer.TransferID) datatransfer.Response { + return &transferResponse1_1{ + Type: uint64(types.CancelMessage), + XferID: uint64(id), + } +} + +// CompleteResponse returns a new complete response message +func CompleteResponse(id datatransfer.TransferID, isAccepted bool, isPaused bool, voucherResultType datatransfer.TypeIdentifier, voucherResult encoding.Encodable) (datatransfer.Response, error) { + vbytes, err := encoding.Encode(voucherResult) + if err != nil { + return nil, xerrors.Errorf("Creating request: %w", err) + } + return &transferResponse1_1{ + Type: uint64(types.CompleteMessage), + Acpt: isAccepted, + Paus: isPaused, + VTyp: voucherResultType, + VRes: &cborgen.Deferred{Raw: vbytes}, + XferID: uint64(id), + }, nil +} + +// FromNet can read a network stream to deserialize a GraphSyncMessage +func FromNet(r io.Reader) (datatransfer.Message, error) { + tresp := transferMessage1_1{} + err := tresp.UnmarshalCBOR(r) + if err != nil { + return nil, err + } + + if (tresp.IsRequest() && tresp.Request == nil) || (!tresp.IsRequest() && tresp.Response == nil) { + return nil, xerrors.Errorf("invalid/malformed message") + } + + if tresp.IsRequest() { + return tresp.Request, nil + } + return tresp.Response, nil +} diff --git a/message/message_test.go b/message/message1_1/message_test.go similarity index 86% rename from message/message_test.go rename to message/message1_1/message_test.go index 949ca8dd..378542bf 100644 --- a/message/message_test.go +++ b/message/message1_1/message_test.go @@ -1,4 +1,4 @@ -package message_test +package message1_1_test import ( "bytes" @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" datatransfer "github.com/filecoin-project/go-data-transfer" - . "github.com/filecoin-project/go-data-transfer/message" + "github.com/filecoin-project/go-data-transfer/message/message1_1" "github.com/filecoin-project/go-data-transfer/testutil" ) @@ -21,7 +21,7 @@ func TestNewRequest(t *testing.T) { isPull := true id := datatransfer.TransferID(rand.Int31()) voucher := testutil.NewFakeDTType() - request, err := NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) require.NoError(t, err) assert.Equal(t, id, request.TransferID()) assert.False(t, request.IsCancel()) @@ -49,7 +49,7 @@ func TestRestartRequest(t *testing.T) { isPull := true id := datatransfer.TransferID(rand.Int31()) voucher := testutil.NewFakeDTType() - request, err := NewRequest(id, true, isPull, voucher.Type(), voucher, baseCid, selector) + request, err := message1_1.NewRequest(id, true, isPull, voucher.Type(), voucher, baseCid, selector) require.NoError(t, err) assert.Equal(t, id, request.TransferID()) assert.False(t, request.IsCancel()) @@ -76,12 +76,12 @@ func TestRestartExistingChannelRequest(t *testing.T) { tid := uint64(1) chid := datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: datatransfer.TransferID(tid)} - req := RestartExistingChannelRequest(chid) + req := message1_1.RestartExistingChannelRequest(chid) wbuf := new(bytes.Buffer) require.NoError(t, req.ToNet(wbuf)) - desMsg, err := FromNet(wbuf) + desMsg, err := message1_1.FromNet(wbuf) require.NoError(t, err) req, ok := (desMsg).(datatransfer.Request) require.True(t, ok) @@ -106,7 +106,7 @@ func TestTransferRequest_UnmarshalCBOR(t *testing.T) { // use ToNet / FromNet require.NoError(t, req.ToNet(wbuf)) - desMsg, err := FromNet(wbuf) + desMsg, err := message1_1.FromNet(wbuf) require.NoError(t, err) // Verify round-trip @@ -124,7 +124,7 @@ func TestTransferRequest_UnmarshalCBOR(t *testing.T) { func TestResponses(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) voucherResult := testutil.NewFakeDTType() - response, err := NewResponse(id, false, true, voucherResult.Type(), voucherResult) // not accepted + response, err := message1_1.NewResponse(id, false, true, voucherResult.Type(), voucherResult) // not accepted require.NoError(t, err) assert.Equal(t, response.TransferID(), id) assert.False(t, response.Accepted()) @@ -147,7 +147,7 @@ func TestResponses(t *testing.T) { func TestTransferResponse_MarshalCBOR(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) voucherResult := testutil.NewFakeDTType() - response, err := NewResponse(id, true, false, voucherResult.Type(), voucherResult) // accepted + response, err := message1_1.NewResponse(id, true, false, voucherResult.Type(), voucherResult) // accepted require.NoError(t, err) // sanity check that we can marshal data @@ -159,14 +159,14 @@ func TestTransferResponse_MarshalCBOR(t *testing.T) { func TestTransferResponse_UnmarshalCBOR(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) voucherResult := testutil.NewFakeDTType() - response, err := NewResponse(id, true, false, voucherResult.Type(), voucherResult) // accepted + response, err := message1_1.NewResponse(id, true, false, voucherResult.Type(), voucherResult) // accepted require.NoError(t, err) wbuf := new(bytes.Buffer) require.NoError(t, response.ToNet(wbuf)) // verify round trip - desMsg, err := FromNet(wbuf) + desMsg, err := message1_1.FromNet(wbuf) require.NoError(t, err) assert.False(t, desMsg.IsRequest()) assert.True(t, desMsg.IsNew()) @@ -185,7 +185,7 @@ func TestTransferResponse_UnmarshalCBOR(t *testing.T) { func TestRequestCancel(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - req := CancelRequest(id) + req := message1_1.CancelRequest(id) require.Equal(t, req.TransferID(), id) require.True(t, req.IsRequest()) require.True(t, req.IsCancel()) @@ -194,7 +194,7 @@ func TestRequestCancel(t *testing.T) { wbuf := new(bytes.Buffer) require.NoError(t, req.ToNet(wbuf)) - deserialized, err := FromNet(wbuf) + deserialized, err := message1_1.FromNet(wbuf) require.NoError(t, err) deserializedRequest, ok := deserialized.(datatransfer.Request) @@ -207,7 +207,7 @@ func TestRequestCancel(t *testing.T) { func TestRequestUpdate(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - req := UpdateRequest(id, true) + req := message1_1.UpdateRequest(id, true) require.Equal(t, req.TransferID(), id) require.True(t, req.IsRequest()) require.False(t, req.IsCancel()) @@ -217,7 +217,7 @@ func TestRequestUpdate(t *testing.T) { wbuf := new(bytes.Buffer) require.NoError(t, req.ToNet(wbuf)) - deserialized, err := FromNet(wbuf) + deserialized, err := message1_1.FromNet(wbuf) require.NoError(t, err) deserializedRequest, ok := deserialized.(datatransfer.Request) @@ -231,7 +231,7 @@ func TestRequestUpdate(t *testing.T) { func TestUpdateResponse(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - response := UpdateResponse(id, true) // not accepted + response := message1_1.UpdateResponse(id, true) // not accepted assert.Equal(t, response.TransferID(), id) assert.False(t, response.Accepted()) assert.False(t, response.IsNew()) @@ -252,7 +252,7 @@ func TestUpdateResponse(t *testing.T) { func TestCancelResponse(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - response := CancelResponse(id) + response := message1_1.CancelResponse(id) assert.Equal(t, response.TransferID(), id) assert.False(t, response.IsNew()) assert.False(t, response.IsUpdate()) @@ -271,7 +271,7 @@ func TestCancelResponse(t *testing.T) { func TestCompleteResponse(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) - response, err := CompleteResponse(id, true, true, datatransfer.EmptyTypeIdentifier, nil) + response, err := message1_1.CompleteResponse(id, true, true, datatransfer.EmptyTypeIdentifier, nil) require.NoError(t, err) assert.Equal(t, response.TransferID(), id) assert.False(t, response.IsNew()) @@ -298,13 +298,13 @@ func TestToNetFromNetEquivalency(t *testing.T) { accepted := false voucher := testutil.NewFakeDTType() voucherResult := testutil.NewFakeDTType() - request, err := NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) require.NoError(t, err) buf := new(bytes.Buffer) err = request.ToNet(buf) require.NoError(t, err) require.Greater(t, buf.Len(), 0) - deserialized, err := FromNet(buf) + deserialized, err := message1_1.FromNet(buf) require.NoError(t, err) deserializedRequest, ok := deserialized.(datatransfer.Request) @@ -318,11 +318,11 @@ func TestToNetFromNetEquivalency(t *testing.T) { testutil.AssertEqualFakeDTVoucher(t, request, deserializedRequest) testutil.AssertEqualSelector(t, request, deserializedRequest) - response, err := NewResponse(id, accepted, false, voucherResult.Type(), voucherResult) + response, err := message1_1.NewResponse(id, accepted, false, voucherResult.Type(), voucherResult) require.NoError(t, err) err = response.ToNet(buf) require.NoError(t, err) - deserialized, err = FromNet(buf) + deserialized, err = message1_1.FromNet(buf) require.NoError(t, err) deserializedResponse, ok := deserialized.(datatransfer.Response) @@ -335,10 +335,10 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.Equal(t, deserializedResponse.IsPaused(), response.IsPaused()) testutil.AssertEqualFakeDTVoucherResult(t, response, deserializedResponse) - request = CancelRequest(id) + request = message1_1.CancelRequest(id) err = request.ToNet(buf) require.NoError(t, err) - deserialized, err = FromNet(buf) + deserialized, err = message1_1.FromNet(buf) require.NoError(t, err) deserializedRequest, ok = deserialized.(datatransfer.Request) @@ -352,13 +352,13 @@ func TestToNetFromNetEquivalency(t *testing.T) { func TestFromNetMessageValidation(t *testing.T) { // craft request message with nil request struct buf := []byte{0x83, 0xf5, 0xf6, 0xf6} - msg, err := FromNet(bytes.NewBuffer(buf)) + msg, err := message1_1.FromNet(bytes.NewBuffer(buf)) assert.Error(t, err) assert.Nil(t, msg) // craft response message with nil response struct buf = []byte{0x83, 0xf4, 0xf6, 0xf6} - msg, err = FromNet(bytes.NewBuffer(buf)) + msg, err = message1_1.FromNet(bytes.NewBuffer(buf)) assert.Error(t, err) assert.Nil(t, msg) } @@ -369,5 +369,5 @@ func NewTestTransferRequest() (datatransfer.Request, error) { isPull := false id := datatransfer.TransferID(rand.Int31()) voucher := testutil.NewFakeDTType() - return NewRequest(id, false, isPull, voucher.Type(), voucher, bcid, selector) + return message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, bcid, selector) } diff --git a/message/message1_1/transfer_message.go b/message/message1_1/transfer_message.go new file mode 100644 index 00000000..c37decf9 --- /dev/null +++ b/message/message1_1/transfer_message.go @@ -0,0 +1,38 @@ +package message1_1 + +import ( + "io" + + datatransfer "github.com/filecoin-project/go-data-transfer" +) + +//go:generate cbor-gen-for transferMessage1_1 + +// transferMessage1_1 is the transfer message for the 1.1 Data Transfer Protocol. +type transferMessage1_1 struct { + IsRq bool + + Request *transferRequest1_1 + Response *transferResponse1_1 +} + +// ========= datatransfer.Message interface + +// IsRequest returns true if this message is a data request +func (tm *transferMessage1_1) IsRequest() bool { + return tm.IsRq +} + +// TransferID returns the TransferID of this message +func (tm *transferMessage1_1) TransferID() datatransfer.TransferID { + if tm.IsRequest() { + return tm.Request.TransferID() + } + return tm.Response.TransferID() +} + +// ToNet serializes a transfer message type. It is simply a wrapper for MarshalCBOR, to provide +// symmetry with FromNet +func (tm *transferMessage1_1) ToNet(w io.Writer) error { + return tm.MarshalCBOR(w) +} diff --git a/message/message1_1/transfer_message_cbor_gen.go b/message/message1_1/transfer_message_cbor_gen.go new file mode 100644 index 00000000..e83fc769 --- /dev/null +++ b/message/message1_1/transfer_message_cbor_gen.go @@ -0,0 +1,117 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + +package message1_1 + +import ( + "fmt" + "io" + + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" +) + +var _ = xerrors.Errorf + +var lengthBuftransferMessage1_1 = []byte{131} + +func (t *transferMessage1_1) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write(lengthBuftransferMessage1_1); err != nil { + return err + } + + // t.IsRq (bool) (bool) + if err := cbg.WriteBool(w, t.IsRq); err != nil { + return err + } + + // t.Request (message1_1.transferRequest1_1) (struct) + if err := t.Request.MarshalCBOR(w); err != nil { + return err + } + + // t.Response (message1_1.transferResponse1_1) (struct) + if err := t.Response.MarshalCBOR(w); err != nil { + return err + } + return nil +} + +func (t *transferMessage1_1) UnmarshalCBOR(r io.Reader) error { + *t = transferMessage1_1{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra != 3 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + // t.IsRq (bool) (bool) + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.IsRq = false + case 21: + t.IsRq = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.Request (message1_1.transferRequest1_1) (struct) + + { + + b, err := br.ReadByte() + if err != nil { + return err + } + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { + return err + } + t.Request = new(transferRequest1_1) + if err := t.Request.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Request pointer: %w", err) + } + } + + } + // t.Response (message1_1.transferResponse1_1) (struct) + + { + + b, err := br.ReadByte() + if err != nil { + return err + } + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { + return err + } + t.Response = new(transferResponse1_1) + if err := t.Response.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Response pointer: %w", err) + } + } + + } + return nil +} diff --git a/message/message1_1/transfer_request.go b/message/message1_1/transfer_request.go new file mode 100644 index 00000000..ba0c42a1 --- /dev/null +++ b/message/message1_1/transfer_request.go @@ -0,0 +1,170 @@ +package message1_1 + +import ( + "bytes" + "io" + + "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/codec/dagcbor" + basicnode "github.com/ipld/go-ipld-prime/node/basic" + "github.com/libp2p/go-libp2p-core/protocol" + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" + + datatransfer "github.com/filecoin-project/go-data-transfer" + "github.com/filecoin-project/go-data-transfer/encoding" + "github.com/filecoin-project/go-data-transfer/message/message1_0" + "github.com/filecoin-project/go-data-transfer/message/types" +) + +//go:generate cbor-gen-for transferRequest1_1 + +// transferRequest1_1 is a struct for the 1.1 Data Transfer Protocol that fulfills the datatransfer.Request interface. +// its members are exported to be used by cbor-gen +type transferRequest1_1 struct { + BCid *cid.Cid + Type uint64 + Paus bool + Part bool + Pull bool + Stor *cbg.Deferred + Vouch *cbg.Deferred + VTyp datatransfer.TypeIdentifier + XferID uint64 + + RestartChannel datatransfer.ChannelID +} + +func (trq *transferRequest1_1) MessageForProtocol(targetProtocol protocol.ID) (datatransfer.Message, error) { + switch targetProtocol { + case datatransfer.ProtocolDataTransfer1_1: + return trq, nil + case datatransfer.ProtocolDataTransfer1_0: + if trq.IsRestart() || trq.IsRestartExistingChannelRequest() { + return nil, xerrors.New("restart not supported on 1.0") + } + + lreq := message1_0.NewTransferRequest( + trq.BCid, + trq.Type, + trq.Paus, + trq.Part, + trq.Pull, + trq.Stor, + trq.Vouch, + trq.VTyp, + trq.XferID, + ) + return lreq, nil + + default: + return nil, xerrors.Errorf("protocol not supported") + } +} + +// IsRequest always returns true in this case because this is a transfer request +func (trq *transferRequest1_1) IsRequest() bool { + return true +} + +func (trq *transferRequest1_1) IsRestart() bool { + return trq.Type == uint64(types.RestartMessage) +} + +func (trq *transferRequest1_1) IsRestartExistingChannelRequest() bool { + return trq.Type == uint64(types.RestartExistingChannelRequestMessage) +} + +func (trq *transferRequest1_1) RestartChannelId() (datatransfer.ChannelID, error) { + if !trq.IsRestartExistingChannelRequest() { + return datatransfer.ChannelID{}, xerrors.New("not a restart request") + } + return trq.RestartChannel, nil +} + +func (trq *transferRequest1_1) IsNew() bool { + return trq.Type == uint64(types.NewMessage) +} + +func (trq *transferRequest1_1) IsUpdate() bool { + return trq.Type == uint64(types.UpdateMessage) +} + +func (trq *transferRequest1_1) IsVoucher() bool { + return trq.Type == uint64(types.VoucherMessage) || trq.Type == uint64(types.NewMessage) +} + +func (trq *transferRequest1_1) IsPaused() bool { + return trq.Paus +} + +func (trq *transferRequest1_1) TransferID() datatransfer.TransferID { + return datatransfer.TransferID(trq.XferID) +} + +// ========= datatransfer.Request interface +// IsPull returns true if this is a data pull request +func (trq *transferRequest1_1) IsPull() bool { + return trq.Pull +} + +// VoucherType returns the Voucher ID +func (trq *transferRequest1_1) VoucherType() datatransfer.TypeIdentifier { + return trq.VTyp +} + +// Voucher returns the Voucher bytes +func (trq *transferRequest1_1) Voucher(decoder encoding.Decoder) (encoding.Encodable, error) { + if trq.Vouch == nil { + return nil, xerrors.New("No voucher present to read") + } + return decoder.DecodeFromCbor(trq.Vouch.Raw) +} + +func (trq *transferRequest1_1) EmptyVoucher() bool { + return trq.VTyp == datatransfer.EmptyTypeIdentifier +} + +// BaseCid returns the Base CID +func (trq *transferRequest1_1) BaseCid() cid.Cid { + if trq.BCid == nil { + return cid.Undef + } + return *trq.BCid +} + +// Selector returns the message Selector bytes +func (trq *transferRequest1_1) Selector() (ipld.Node, error) { + if trq.Stor == nil { + return nil, xerrors.New("No selector present to read") + } + builder := basicnode.Prototype.Any.NewBuilder() + reader := bytes.NewReader(trq.Stor.Raw) + err := dagcbor.Decoder(builder, reader) + if err != nil { + return nil, xerrors.Errorf("Error decoding selector: %w", err) + } + return builder.Build(), nil +} + +// IsCancel returns true if this is a cancel request +func (trq *transferRequest1_1) IsCancel() bool { + return trq.Type == uint64(types.CancelMessage) +} + +// IsPartial returns true if this is a partial request +func (trq *transferRequest1_1) IsPartial() bool { + return trq.Part +} + +// ToNet serializes a transfer request. It's a wrapper for MarshalCBOR to provide +// symmetry with FromNet +func (trq *transferRequest1_1) ToNet(w io.Writer) error { + msg := transferMessage1_1{ + IsRq: true, + Request: trq, + Response: nil, + } + return msg.MarshalCBOR(w) +} diff --git a/message/transfer_request_cbor_gen.go b/message/message1_1/transfer_request_cbor_gen.go similarity index 94% rename from message/transfer_request_cbor_gen.go rename to message/message1_1/transfer_request_cbor_gen.go index 5acaa79a..d63dad39 100644 --- a/message/transfer_request_cbor_gen.go +++ b/message/message1_1/transfer_request_cbor_gen.go @@ -1,6 +1,6 @@ // Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. -package message +package message1_1 import ( "fmt" @@ -13,14 +13,14 @@ import ( var _ = xerrors.Errorf -var lengthBuftransferRequest = []byte{138} +var lengthBuftransferRequest1_1 = []byte{138} -func (t *transferRequest) MarshalCBOR(w io.Writer) error { +func (t *transferRequest1_1) MarshalCBOR(w io.Writer) error { if t == nil { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write(lengthBuftransferRequest); err != nil { + if _, err := w.Write(lengthBuftransferRequest1_1); err != nil { return err } @@ -94,8 +94,8 @@ func (t *transferRequest) MarshalCBOR(w io.Writer) error { return nil } -func (t *transferRequest) UnmarshalCBOR(r io.Reader) error { - *t = transferRequest{} +func (t *transferRequest1_1) UnmarshalCBOR(r io.Reader) error { + *t = transferRequest1_1{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) diff --git a/message/message1_1/transfer_request_test.go b/message/message1_1/transfer_request_test.go new file mode 100644 index 00000000..ee1e0940 --- /dev/null +++ b/message/message1_1/transfer_request_test.go @@ -0,0 +1,69 @@ +package message1_1_test + +import ( + "math/rand" + "testing" + + basicnode "github.com/ipld/go-ipld-prime/node/basic" + "github.com/ipld/go-ipld-prime/traversal/selector/builder" + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer" + "github.com/filecoin-project/go-data-transfer/message/message1_1" + "github.com/filecoin-project/go-data-transfer/testutil" +) + +func TestRequestMessageForProtocol(t *testing.T) { + baseCid := testutil.GenerateCids(1)[0] + selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() + isPull := true + id := datatransfer.TransferID(rand.Int31()) + voucher := testutil.NewFakeDTType() + + // for the new protocol + request, err := message1_1.NewRequest(id, false, isPull, voucher.Type(), voucher, baseCid, selector) + require.NoError(t, err) + + out, err := request.MessageForProtocol(datatransfer.ProtocolDataTransfer1_1) + require.NoError(t, err) + require.Equal(t, request, out) + + // for the old protocol + out, err = request.MessageForProtocol(datatransfer.ProtocolDataTransfer1_0) + require.NoError(t, err) + req, ok := out.(datatransfer.Request) + require.True(t, ok) + require.False(t, req.IsRestart()) + require.False(t, req.IsRestartExistingChannelRequest()) + require.Equal(t, baseCid, req.BaseCid()) + require.True(t, req.IsPull()) + n, err := req.Selector() + require.NoError(t, err) + require.Equal(t, selector, n) + require.Equal(t, voucher.Type(), req.VoucherType()) + + // random protocol + out, err = request.MessageForProtocol("RAND") + require.Error(t, err) + require.Nil(t, out) +} + +func TestRequestMessageForProtocolRestartDowngradeFails(t *testing.T) { + baseCid := testutil.GenerateCids(1)[0] + selector := builder.NewSelectorSpecBuilder(basicnode.Prototype.Any).Matcher().Node() + isPull := true + id := datatransfer.TransferID(rand.Int31()) + voucher := testutil.NewFakeDTType() + + request, err := message1_1.NewRequest(id, true, isPull, voucher.Type(), voucher, baseCid, selector) + require.NoError(t, err) + + out, err := request.MessageForProtocol(datatransfer.ProtocolDataTransfer1_0) + require.Nil(t, out) + require.EqualError(t, err, "restart not supported on 1.0") + + req2 := message1_1.RestartExistingChannelRequest(datatransfer.ChannelID{}) + out, err = req2.MessageForProtocol(datatransfer.ProtocolDataTransfer1_0) + require.Nil(t, out) + require.EqualError(t, err, "restart not supported on 1.0") +} diff --git a/message/message1_1/transfer_response.go b/message/message1_1/transfer_response.go new file mode 100644 index 00000000..d2a93876 --- /dev/null +++ b/message/message1_1/transfer_response.go @@ -0,0 +1,126 @@ +package message1_1 + +import ( + "io" + + "github.com/libp2p/go-libp2p-core/protocol" + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" + + datatransfer "github.com/filecoin-project/go-data-transfer" + "github.com/filecoin-project/go-data-transfer/encoding" + "github.com/filecoin-project/go-data-transfer/message/message1_0" + "github.com/filecoin-project/go-data-transfer/message/types" +) + +//go:generate cbor-gen-for transferResponse1_1 + +// transferResponse1_1 is a private struct that satisfies the datatransfer.Response interface +// It is the response message for the Data Transfer 1.1 Protocol. +type transferResponse1_1 struct { + Type uint64 + Acpt bool + Paus bool + XferID uint64 + VRes *cbg.Deferred + VTyp datatransfer.TypeIdentifier +} + +func (trsp *transferResponse1_1) TransferID() datatransfer.TransferID { + return datatransfer.TransferID(trsp.XferID) +} + +// IsRequest always returns false in this case because this is a transfer response +func (trsp *transferResponse1_1) IsRequest() bool { + return false +} + +// IsNew returns true if this is the first response sent +func (trsp *transferResponse1_1) IsNew() bool { + return trsp.Type == uint64(types.NewMessage) +} + +// IsUpdate returns true if this response is an update +func (trsp *transferResponse1_1) IsUpdate() bool { + return trsp.Type == uint64(types.UpdateMessage) +} + +// IsPaused returns true if the responder is paused +func (trsp *transferResponse1_1) IsPaused() bool { + return trsp.Paus +} + +// IsCancel returns true if the responder has cancelled this response +func (trsp *transferResponse1_1) IsCancel() bool { + return trsp.Type == uint64(types.CancelMessage) +} + +// IsComplete returns true if the responder has completed this response +func (trsp *transferResponse1_1) IsComplete() bool { + return trsp.Type == uint64(types.CompleteMessage) +} + +func (trsp *transferResponse1_1) IsVoucherResult() bool { + return trsp.Type == uint64(types.VoucherResultMessage) || trsp.Type == uint64(types.NewMessage) || trsp.Type == uint64(types.CompleteMessage) || + trsp.Type == uint64(types.RestartMessage) +} + +// Accepted returns true if the request is accepted in the response +func (trsp *transferResponse1_1) Accepted() bool { + return trsp.Acpt +} + +func (trsp *transferResponse1_1) VoucherResultType() datatransfer.TypeIdentifier { + return trsp.VTyp +} + +func (trsp *transferResponse1_1) VoucherResult(decoder encoding.Decoder) (encoding.Encodable, error) { + if trsp.VRes == nil { + return nil, xerrors.New("No voucher present to read") + } + return decoder.DecodeFromCbor(trsp.VRes.Raw) +} + +func (trq *transferResponse1_1) IsRestart() bool { + return trq.Type == uint64(types.RestartMessage) +} + +func (trsp *transferResponse1_1) EmptyVoucherResult() bool { + return trsp.VTyp == datatransfer.EmptyTypeIdentifier +} + +func (trsp *transferResponse1_1) MessageForProtocol(targetProtocol protocol.ID) (datatransfer.Message, error) { + switch targetProtocol { + case datatransfer.ProtocolDataTransfer1_1: + return trsp, nil + case datatransfer.ProtocolDataTransfer1_0: + // this should never happen but dosen't hurt to have this here for sanity + if trsp.IsRestart() { + return nil, xerrors.New("restart not supported for 1.0 protocol") + } + + lresp := message1_0.NewTransferResponse( + trsp.Type, + trsp.Acpt, + trsp.Paus, + trsp.XferID, + trsp.VRes, + trsp.VTyp, + ) + + return lresp, nil + default: + return nil, xerrors.Errorf("protocol %s not supported", targetProtocol) + } +} + +// ToNet serializes a transfer response. It's a wrapper for MarshalCBOR to provide +// symmetry with FromNet +func (trsp *transferResponse1_1) ToNet(w io.Writer) error { + msg := transferMessage1_1{ + IsRq: false, + Request: nil, + Response: trsp, + } + return msg.MarshalCBOR(w) +} diff --git a/message/message1_1/transfer_response_cbor_gen.go b/message/message1_1/transfer_response_cbor_gen.go new file mode 100644 index 00000000..94e9e677 --- /dev/null +++ b/message/message1_1/transfer_response_cbor_gen.go @@ -0,0 +1,171 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + +package message1_1 + +import ( + "fmt" + "io" + + datatransfer "github.com/filecoin-project/go-data-transfer" + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" +) + +var _ = xerrors.Errorf + +var lengthBuftransferResponse1_1 = []byte{134} + +func (t *transferResponse1_1) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write(lengthBuftransferResponse1_1); err != nil { + return err + } + + scratch := make([]byte, 9) + + // t.Type (uint64) (uint64) + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.Type)); err != nil { + return err + } + + // t.Acpt (bool) (bool) + if err := cbg.WriteBool(w, t.Acpt); err != nil { + return err + } + + // t.Paus (bool) (bool) + if err := cbg.WriteBool(w, t.Paus); err != nil { + return err + } + + // t.XferID (uint64) (uint64) + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.XferID)); err != nil { + return err + } + + // t.VRes (typegen.Deferred) (struct) + if err := t.VRes.MarshalCBOR(w); err != nil { + return err + } + + // t.VTyp (datatransfer.TypeIdentifier) (string) + if len(t.VTyp) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.VTyp was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.VTyp))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.VTyp)); err != nil { + return err + } + return nil +} + +func (t *transferResponse1_1) UnmarshalCBOR(r io.Reader) error { + *t = transferResponse1_1{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra != 6 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + // t.Type (uint64) (uint64) + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.Type = uint64(extra) + + } + // t.Acpt (bool) (bool) + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.Acpt = false + case 21: + t.Acpt = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.Paus (bool) (bool) + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.Paus = false + case 21: + t.Paus = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + // t.XferID (uint64) (uint64) + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.XferID = uint64(extra) + + } + // t.VRes (typegen.Deferred) (struct) + + { + + t.VRes = new(cbg.Deferred) + + if err := t.VRes.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("failed to read deferred field: %w", err) + } + } + // t.VTyp (datatransfer.TypeIdentifier) (string) + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.VTyp = datatransfer.TypeIdentifier(sval) + } + return nil +} diff --git a/message/message1_1/transfer_response_test.go b/message/message1_1/transfer_response_test.go new file mode 100644 index 00000000..9ea922d9 --- /dev/null +++ b/message/message1_1/transfer_response_test.go @@ -0,0 +1,49 @@ +package message1_1_test + +import ( + "math/rand" + "testing" + + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer" + "github.com/filecoin-project/go-data-transfer/message/message1_1" + "github.com/filecoin-project/go-data-transfer/testutil" +) + +func TestResponseMessageForProtocol(t *testing.T) { + id := datatransfer.TransferID(rand.Int31()) + voucherResult := testutil.NewFakeDTType() + response, err := message1_1.NewResponse(id, false, true, voucherResult.Type(), voucherResult) // not accepted + require.NoError(t, err) + + // new protocol + out, err := response.MessageForProtocol(datatransfer.ProtocolDataTransfer1_1) + require.NoError(t, err) + require.Equal(t, response, out) + + // old protocol + out, err = response.MessageForProtocol(datatransfer.ProtocolDataTransfer1_0) + require.NoError(t, err) + resp, ok := (out).(datatransfer.Response) + require.True(t, ok) + require.True(t, resp.IsPaused()) + require.Equal(t, voucherResult.Type(), resp.VoucherResultType()) + require.True(t, resp.IsVoucherResult()) + + // random protocol + out, err = response.MessageForProtocol("RAND") + require.Error(t, err) + require.Nil(t, out) +} + +func TestResponseMessageForProtocolFail(t *testing.T) { + id := datatransfer.TransferID(rand.Int31()) + voucherResult := testutil.NewFakeDTType() + response, err := message1_1.RestartResponse(id, false, true, voucherResult.Type(), voucherResult) // not accepted + require.NoError(t, err) + + out, err := response.MessageForProtocol(datatransfer.ProtocolDataTransfer1_0) + require.Nil(t, out) + require.EqualError(t, err, "restart not supported for 1.0 protocol") +} diff --git a/message/types/message_types.go b/message/types/message_types.go new file mode 100644 index 00000000..3144df0a --- /dev/null +++ b/message/types/message_types.go @@ -0,0 +1,16 @@ +package types + +type MessageType uint64 + +// Always append at the end to avoid breaking backward compatibility for cbor messages +const ( + NewMessage MessageType = iota + UpdateMessage + CancelMessage + CompleteMessage + VoucherMessage + VoucherResultMessage + + RestartMessage + RestartExistingChannelRequestMessage +) diff --git a/network/interface.go b/network/interface.go index 999d680f..d6e2ffc9 100644 --- a/network/interface.go +++ b/network/interface.go @@ -4,16 +4,10 @@ import ( "context" "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/protocol" datatransfer "github.com/filecoin-project/go-data-transfer" ) -var ( - // ProtocolDataTransfer is the protocol identifier for graphsync messages - ProtocolDataTransfer protocol.ID = "/fil/datatransfer/1.0.0" -) - // DataTransferNetwork provides network connectivity for GraphSync. type DataTransferNetwork interface { Protect(id peer.ID, tag string) diff --git a/network/libp2p_impl.go b/network/libp2p_impl.go index fcaeae0f..ebb82b98 100644 --- a/network/libp2p_impl.go +++ b/network/libp2p_impl.go @@ -17,6 +17,7 @@ import ( datatransfer "github.com/filecoin-project/go-data-transfer" "github.com/filecoin-project/go-data-transfer/message" + "github.com/filecoin-project/go-data-transfer/message/message1_0" ) var log = logging.Logger("data_transfer_network") @@ -27,9 +28,19 @@ const defaultMaxStreamOpenAttempts = 5 const defaultMinAttemptDuration = 1 * time.Second const defaultMaxAttemptDuration = 5 * time.Minute +var defaultDataTransferProtocols = []protocol.ID{datatransfer.ProtocolDataTransfer1_1, datatransfer.ProtocolDataTransfer1_0} + // Option is an option for configuring the libp2p storage market network type Option func(*libp2pDataTransferNetwork) +// DataTransferProtocols OVERWRITES the default libp2p protocols we use for data transfer with the given protocols. +func DataTransferProtocols(protocols []protocol.ID) Option { + return func(impl *libp2pDataTransferNetwork) { + impl.dtProtocols = nil + impl.dtProtocols = append(impl.dtProtocols, protocols...) + } +} + // RetryParameters changes the default parameters around connection reopening func RetryParameters(minDuration time.Duration, maxDuration time.Duration, attempts float64) Option { return func(impl *libp2pDataTransferNetwork) { @@ -47,6 +58,7 @@ func NewFromLibp2pHost(host host.Host, options ...Option) DataTransferNetwork { maxStreamOpenAttempts: defaultMaxStreamOpenAttempts, minAttemptDuration: defaultMinAttemptDuration, maxAttemptDuration: defaultMaxAttemptDuration, + dtProtocols: defaultDataTransferProtocols, } for _, option := range options { @@ -66,9 +78,10 @@ type libp2pDataTransferNetwork struct { maxStreamOpenAttempts float64 minAttemptDuration time.Duration maxAttemptDuration time.Duration + dtProtocols []protocol.ID } -func (impl *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.ID, protocol protocol.ID) (network.Stream, error) { +func (impl *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.ID, protocols ...protocol.ID) (network.Stream, error) { b := &backoff.Backoff{ Min: impl.minAttemptDuration, Max: impl.maxAttemptDuration, @@ -77,7 +90,8 @@ func (impl *libp2pDataTransferNetwork) openStream(ctx context.Context, id peer.I } for { - s, err := impl.host.NewStream(ctx, id, protocol) + // will use the first among the given protocols that the remote peer supports + s, err := impl.host.NewStream(ctx, id, protocols...) if err == nil { return s, err } @@ -96,11 +110,16 @@ func (dtnet *libp2pDataTransferNetwork) SendMessage( p peer.ID, outgoing datatransfer.Message) error { - s, err := dtnet.openStream(ctx, p, ProtocolDataTransfer) + s, err := dtnet.openStream(ctx, p, dtnet.dtProtocols...) if err != nil { return err } + outgoing, err = outgoing.MessageForProtocol(s.Protocol()) + if err != nil { + return xerrors.Errorf("failed to convert message for protocol: %w", err) + } + if err = msgToStream(ctx, s, outgoing); err != nil { if err2 := s.Reset(); err2 != nil { log.Error(err) @@ -112,12 +131,13 @@ func (dtnet *libp2pDataTransferNetwork) SendMessage( // TODO(https://github.com/libp2p/go-libp2p-net/issues/28): Avoid this goroutine. go helpers.AwaitEOF(s) // nolint: errcheck,gosec return s.Close() - } func (dtnet *libp2pDataTransferNetwork) SetDelegate(r Receiver) { dtnet.receiver = r - dtnet.host.SetStreamHandler(ProtocolDataTransfer, dtnet.handleNewStream) + for _, p := range dtnet.dtProtocols { + dtnet.host.SetStreamHandler(p, dtnet.handleNewStream) + } } func (dtnet *libp2pDataTransferNetwork) ConnectTo(ctx context.Context, p peer.ID) error { @@ -134,7 +154,14 @@ func (dtnet *libp2pDataTransferNetwork) handleNewStream(s network.Stream) { } for { - received, err := message.FromNet(s) + var received datatransfer.Message + var err error + if s.Protocol() == datatransfer.ProtocolDataTransfer1_1 { + received, err = message.FromNet(s) + } else { + received, err = message1_0.FromNet(s) + } + if err != nil { if err != io.EOF { s.Reset() // nolint: errcheck,gosec @@ -192,15 +219,17 @@ func msgToStream(ctx context.Context, s network.Stream, msg datatransfer.Message } switch s.Protocol() { - case ProtocolDataTransfer: - if err := msg.ToNet(s); err != nil { - log.Debugf("error: %s", err) - return err - } + case datatransfer.ProtocolDataTransfer1_1: + case datatransfer.ProtocolDataTransfer1_0: default: return fmt.Errorf("unrecognized protocol on remote: %s", s.Protocol()) } + if err := msg.ToNet(s); err != nil { + log.Debugf("error: %s", err) + return err + } + if err := s.SetWriteDeadline(time.Time{}); err != nil { log.Warnf("error resetting deadline: %s", err) } diff --git a/testutil/gstestdata.go b/testutil/gstestdata.go index 044980e8..e3450778 100644 --- a/testutil/gstestdata.go +++ b/testutil/gstestdata.go @@ -34,6 +34,7 @@ import ( "github.com/ipld/go-ipld-prime/traversal/selector" "github.com/ipld/go-ipld-prime/traversal/selector/builder" "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/protocol" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/stretchr/testify/require" @@ -85,7 +86,7 @@ type GraphsyncTestingData struct { } // NewGraphsyncTestingData returns a new GraphsyncTestingData instance -func NewGraphsyncTestingData(ctx context.Context, t *testing.T) *GraphsyncTestingData { +func NewGraphsyncTestingData(ctx context.Context, t *testing.T, host1Protocols []protocol.ID, host2Protocols []protocol.ID) *GraphsyncTestingData { gsData := &GraphsyncTestingData{} gsData.Ctx = ctx @@ -130,8 +131,19 @@ func NewGraphsyncTestingData(ctx context.Context, t *testing.T) *GraphsyncTestin gsData.GsNet1 = gsnet.NewFromLibp2pHost(gsData.Host1) gsData.GsNet2 = gsnet.NewFromLibp2pHost(gsData.Host2) - gsData.DtNet1 = network.NewFromLibp2pHost(gsData.Host1, network.RetryParameters(0, 0, 0)) - gsData.DtNet2 = network.NewFromLibp2pHost(gsData.Host2, network.RetryParameters(0, 0, 0)) + opts1 := []network.Option{network.RetryParameters(0, 0, 0)} + opts2 := []network.Option{network.RetryParameters(0, 0, 0)} + + if len(host1Protocols) != 0 { + opts1 = append(opts1, network.DataTransferProtocols(host1Protocols)) + } + + if len(host2Protocols) != 0 { + opts2 = append(opts2, network.DataTransferProtocols(host2Protocols)) + } + + gsData.DtNet1 = network.NewFromLibp2pHost(gsData.Host1, opts1...) + gsData.DtNet2 = network.NewFromLibp2pHost(gsData.Host2, opts2...) // create a selector for the whole UnixFS dag gsData.AllSelector = allSelector