diff --git a/channels/channels.go b/channels/channels.go index ad5ecb32..af9947c1 100644 --- a/channels/channels.go +++ b/channels/channels.go @@ -131,7 +131,7 @@ func (c *Channels) CreateNew(tid datatransfer.TransferID, baseCid cid.Cid, selec } // InProgress returns a list of in progress channels -func (c *Channels) InProgress(ctx context.Context) (map[datatransfer.ChannelID]datatransfer.ChannelState, error) { +func (c *Channels) InProgress() (map[datatransfer.ChannelID]datatransfer.ChannelState, error) { var internalChannels []internalChannelState err := c.statemachines.List(&internalChannels) if err != nil { @@ -261,14 +261,3 @@ func (c *Channels) send(chid datatransfer.ChannelID, code datatransfer.EventCode } return c.statemachines.Send(chid, code, args...) } - -func (c *Channels) sendSync(ctx context.Context, chid datatransfer.ChannelID, code datatransfer.EventCode, args ...interface{}) error { - has, err := c.statemachines.Has(chid) - if err != nil { - return err - } - if !has { - return ErrNotFound - } - return c.statemachines.SendSync(ctx, chid, code, args...) -} diff --git a/channels/channels_test.go b/channels/channels_test.go index c8601261..2782744e 100644 --- a/channels/channels_test.go +++ b/channels/channels_test.go @@ -70,7 +70,7 @@ func TestChannels(t *testing.T) { }) t.Run("in progress channels", func(t *testing.T) { - inProgress, err := channelList.InProgress(ctx) + inProgress, err := channelList.InProgress() require.NoError(t, err) require.Len(t, inProgress, 2) require.Contains(t, inProgress, datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}) diff --git a/go.mod b/go.mod index bcc99f3d..08b7f683 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/filecoin-project/go-storedcounter v0.0.0-20200421200003-1c99c62e8a5b github.com/hannahhoward/cbor-gen-for v0.0.0-20191218204337-9ab7b1bcc099 github.com/hannahhoward/go-pubsub v0.0.0-20200423002714-8d62886cc36e + github.com/hashicorp/go-multierror v1.1.0 github.com/ipfs/go-block-format v0.0.2 github.com/ipfs/go-blockservice v0.1.3 github.com/ipfs/go-cid v0.0.7 diff --git a/go.sum b/go.sum index 3922c1cb..e410bf95 100644 --- a/go.sum +++ b/go.sum @@ -95,6 +95,10 @@ github.com/hannahhoward/cbor-gen-for v0.0.0-20191218204337-9ab7b1bcc099 h1:vQqOW github.com/hannahhoward/cbor-gen-for v0.0.0-20191218204337-9ab7b1bcc099/go.mod h1:WVPCl0HO/0RAL5+vBH2GMxBomlxBF70MAS78+Lu1//k= github.com/hannahhoward/go-pubsub v0.0.0-20200423002714-8d62886cc36e h1:3YKHER4nmd7b5qy5t0GWDTwSn4OyRgfAXSmo6VnryBY= github.com/hannahhoward/go-pubsub v0.0.0-20200423002714-8d62886cc36e/go.mod h1:I8h3MITA53gN9OnWGCgaMa0JWVRdXthWw4M3CPM54OY= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= +github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= diff --git a/impl/impl.go b/impl/impl.go index ace93b91..7d684f30 100644 --- a/impl/impl.go +++ b/impl/impl.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/hannahhoward/go-pubsub" + "github.com/hashicorp/go-multierror" "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" logging "github.com/ipfs/go-log/v2" @@ -101,8 +102,20 @@ func (m *manager) Start(ctx context.Context) error { } // Stop terminates all data transfers and ends processing -func (m *manager) Stop() error { - return nil +func (m *manager) Stop(ctx context.Context) error { + openChannels, err := m.channels.InProgress() + if err != nil { + return xerrors.Errorf("error getting channels in progress: %w", err) + } + + var result error + for chid := range openChannels { + if err := m.CloseDataTransferChannel(ctx, chid); err != nil { + result = multierror.Append(result, xerrors.Errorf("error closing channel with ID %v, err: %w", chid, err)) + } + } + + return result } // RegisterVoucherType registers a validator for the given voucher type @@ -266,7 +279,7 @@ func (m *manager) SubscribeToEvents(subscriber datatransfer.Subscriber) datatran // get all in progress transfers func (m *manager) InProgressChannels(ctx context.Context) (map[datatransfer.ChannelID]datatransfer.ChannelState, error) { - return m.channels.InProgress(ctx) + return m.channels.InProgress() } // RegisterRevalidator registers a revalidator for the given voucher type diff --git a/manager.go b/manager.go index 1658f0c6..64210f80 100644 --- a/manager.go +++ b/manager.go @@ -58,7 +58,7 @@ type Manager interface { Start(ctx context.Context) error // Stop terminates all data transfers and ends processing - Stop() error + Stop(ctx context.Context) error // RegisterVoucherType registers a validator for the given voucher type // will error if voucher type does not implement voucher diff --git a/network/libp2p_impl.go b/network/libp2p_impl.go index ee304ba4..88b70a8a 100644 --- a/network/libp2p_impl.go +++ b/network/libp2p_impl.go @@ -37,51 +37,6 @@ type libp2pDataTransferNetwork struct { receiver Receiver } -type streamMessageSender struct { - s network.Stream -} - -func (s *streamMessageSender) Close() error { - return helpers.FullClose(s.s) -} - -func (s *streamMessageSender) Reset() error { - return s.s.Reset() -} - -func (s *streamMessageSender) SendMsg(ctx context.Context, msg datatransfer.Message) error { - return msgToStream(ctx, s.s, msg) -} - -func msgToStream(ctx context.Context, s network.Stream, msg datatransfer.Message) error { - if msg.IsRequest() { - log.Debugf("Outgoing request message for transfer ID: %d", msg.TransferID()) - } - - deadline := time.Now().Add(sendMessageTimeout) - if dl, ok := ctx.Deadline(); ok { - deadline = dl - } - if err := s.SetWriteDeadline(deadline); err != nil { - log.Warnf("error setting deadline: %s", err) - } - - switch s.Protocol() { - case ProtocolDataTransfer: - if err := msg.ToNet(s); err != nil { - log.Debugf("error: %s", err) - return err - } - default: - return fmt.Errorf("unrecognized protocol on remote: %s", s.Protocol()) - } - - if err := s.SetWriteDeadline(time.Time{}); err != nil { - log.Warnf("error resetting deadline: %s", err) - } - return nil -} - func (dtnet *libp2pDataTransferNetwork) newStreamToPeer(ctx context.Context, p peer.ID) (network.Stream, error) { return dtnet.host.NewStream(ctx, p, ProtocolDataTransfer) } @@ -167,3 +122,32 @@ func (dtnet *libp2pDataTransferNetwork) Protect(id peer.ID, tag string) { func (dtnet *libp2pDataTransferNetwork) Unprotect(id peer.ID, tag string) bool { return dtnet.host.ConnManager().Unprotect(id, tag) } + +func msgToStream(ctx context.Context, s network.Stream, msg datatransfer.Message) error { + if msg.IsRequest() { + log.Debugf("Outgoing request message for transfer ID: %d", msg.TransferID()) + } + + deadline := time.Now().Add(sendMessageTimeout) + if dl, ok := ctx.Deadline(); ok { + deadline = dl + } + if err := s.SetWriteDeadline(deadline); err != nil { + log.Warnf("error setting deadline: %s", err) + } + + switch s.Protocol() { + case ProtocolDataTransfer: + if err := msg.ToNet(s); err != nil { + log.Debugf("error: %s", err) + return err + } + default: + return fmt.Errorf("unrecognized protocol on remote: %s", s.Protocol()) + } + + if err := s.SetWriteDeadline(time.Time{}); err != nil { + log.Warnf("error resetting deadline: %s", err) + } + return nil +} diff --git a/types.go b/types.go index 2b12667d..22752a7b 100644 --- a/types.go +++ b/types.go @@ -10,12 +10,6 @@ import ( "github.com/filecoin-project/go-data-transfer/encoding" ) -type errorString string - -func (es errorString) Error() string { - return string(es) -} - //go:generate cbor-gen-for ChannelID // TypeIdentifier is a unique string identifier for a type of encodable object in a