diff --git a/cmd/ipfs/init.go b/cmd/ipfs/init.go index f75251b098e..2c4446c58bc 100644 --- a/cmd/ipfs/init.go +++ b/cmd/ipfs/init.go @@ -10,7 +10,7 @@ import ( "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/commander" config "github.com/jbenet/go-ipfs/config" ci "github.com/jbenet/go-ipfs/crypto" - spipe "github.com/jbenet/go-ipfs/crypto/spipe" + peer "github.com/jbenet/go-ipfs/peer" updates "github.com/jbenet/go-ipfs/updates" u "github.com/jbenet/go-ipfs/util" ) @@ -121,7 +121,7 @@ func initCmd(c *commander.Command, inp []string) error { } cfg.Identity.PrivKey = base64.StdEncoding.EncodeToString(skbytes) - id, err := spipe.IDFromPubKey(pk) + id, err := peer.IDFromPubKey(pk) if err != nil { return err } diff --git a/cmd/ipfs/ipfs.go b/cmd/ipfs/ipfs.go index fb5b265971c..2211985ec70 100644 --- a/cmd/ipfs/ipfs.go +++ b/cmd/ipfs/ipfs.go @@ -41,8 +41,7 @@ Advanced Commands: mount Mount an ipfs read-only mountpoint. serve Serve an interface to ipfs. - - net-diag Print network diagnostic + net-diag Print network diagnostic Use "ipfs help " for more information about a command. `, diff --git a/core/commands/diag.go b/core/commands/diag.go index c06499ec63b..fdb84ecf493 100644 --- a/core/commands/diag.go +++ b/core/commands/diag.go @@ -8,8 +8,22 @@ import ( "time" "github.com/jbenet/go-ipfs/core" + diagn "github.com/jbenet/go-ipfs/diagnostics" ) +func PrintDiagnostics(info []*diagn.DiagInfo, out io.Writer) { + for _, i := range info { + fmt.Fprintf(out, "Peer: %s\n", i.ID) + fmt.Fprintf(out, "\tUp for: %s\n", i.LifeSpan.String()) + fmt.Fprintf(out, "\tConnected To:\n") + for _, c := range i.Connections { + fmt.Fprintf(out, "\t%s\n\t\tLatency = %s\n", c.ID, c.Latency.String()) + } + fmt.Fprintln(out) + } + +} + func Diag(n *core.IpfsNode, args []string, opts map[string]interface{}, out io.Writer) error { if n.Diagnostics == nil { return errors.New("Cannot run diagnostic in offline mode!") @@ -29,15 +43,7 @@ func Diag(n *core.IpfsNode, args []string, opts map[string]interface{}, out io.W return err } } else { - for _, i := range info { - fmt.Fprintf(out, "Peer: %s\n", i.ID) - fmt.Fprintf(out, "\tUp for: %s\n", i.LifeSpan.String()) - fmt.Fprintf(out, "\tConnected To:\n") - for _, c := range i.Connections { - fmt.Fprintf(out, "\t%s\n\t\tLatency = %s\n", c.ID, c.Latency.String()) - } - fmt.Fprintln(out) - } + PrintDiagnostics(info, out) } return nil } diff --git a/core/core.go b/core/core.go index d22390d9296..331299fec8c 100644 --- a/core/core.go +++ b/core/core.go @@ -108,6 +108,7 @@ func NewIpfsNode(cfg *config.Config, online bool) (*IpfsNode, error) { route *dht.IpfsDHT exchangeSession exchange.Interface diagnostics *diag.Diagnostics + network inet.Network ) if online { @@ -135,11 +136,12 @@ func NewIpfsNode(cfg *config.Config, online bool) (*IpfsNode, error) { if err != nil { return nil, err } + network = net diagnostics = diag.NewDiagnostics(local, net, diagService) diagService.SetHandler(diagnostics) - route = dht.NewDHT(local, peerstore, net, dhtService, d) + route = dht.NewDHT(ctx, local, peerstore, net, dhtService, d) // TODO(brian): perform this inside NewDHT factory method dhtService.SetHandler(route) // wire the handler to the service. @@ -173,6 +175,7 @@ func NewIpfsNode(cfg *config.Config, online bool) (*IpfsNode, error) { Routing: route, Namesys: ns, Diagnostics: diagnostics, + Network: network, }, nil } diff --git a/crypto/key.go b/crypto/key.go index 4b40feb6dd9..b26d231ea9b 100644 --- a/crypto/key.go +++ b/crypto/key.go @@ -99,7 +99,7 @@ func GenerateEKeyPair(curveName string) ([]byte, GenSharedKey, error) { } pubKey := elliptic.Marshal(curve, x, y) - log.Debug("GenerateEKeyPair %d", len(pubKey)) + // log.Debug("GenerateEKeyPair %d", len(pubKey)) done := func(theirPub []byte) ([]byte, error) { // Verify and unpack node's public key. diff --git a/crypto/spipe/handshake.go b/crypto/spipe/handshake.go index ea06afbddb3..e5f3c94a9fd 100644 --- a/crypto/spipe/handshake.go +++ b/crypto/spipe/handshake.go @@ -18,6 +18,7 @@ import ( "hash" proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto" + ci "github.com/jbenet/go-ipfs/crypto" peer "github.com/jbenet/go-ipfs/peer" u "github.com/jbenet/go-ipfs/util" @@ -204,7 +205,7 @@ func (s *SecurePipe) handshake() error { } if bytes.Compare(resp2, finished) != 0 { - return errors.New("Negotiation failed.") + return fmt.Errorf("Negotiation failed, got: %s", resp2) } log.Debug("%s handshake: Got node id: %s", s.local, s.remote) @@ -229,7 +230,15 @@ func (s *SecurePipe) handleSecureIn(hashType string, tIV, tCKey, tMKey []byte) { theirMac, macSize := makeMac(hashType, tMKey) for { - data, ok := <-s.insecure.In + var data []byte + ok := true + + select { + case <-s.ctx.Done(): + ok = false // return out + case data, ok = <-s.insecure.In: + } + if !ok { close(s.Duplex.In) return @@ -266,8 +275,17 @@ func (s *SecurePipe) handleSecureOut(hashType string, mIV, mCKey, mMKey []byte) myMac, macSize := makeMac(hashType, mMKey) for { - data, ok := <-s.Out + var data []byte + ok := true + + select { + case <-s.ctx.Done(): + ok = false // return out + case data, ok = <-s.Out: + } + if !ok { + close(s.insecure.Out) return } @@ -288,16 +306,6 @@ func (s *SecurePipe) handleSecureOut(hashType string, mIV, mCKey, mMKey []byte) } } -// IDFromPubKey retrieves a Public Key from the peer given by pk -func IDFromPubKey(pk ci.PubKey) (peer.ID, error) { - b, err := pk.Bytes() - if err != nil { - return nil, err - } - hash := u.Hash(b) - return peer.ID(hash), nil -} - // Determines which algorithm to use. Note: f(a, b) = f(b, a) func selectBest(myPrefs, theirPrefs string) (string, error) { // Person with greatest hash gets first choice. @@ -334,7 +342,7 @@ func selectBest(myPrefs, theirPrefs string) (string, error) { // else, construct it. func getOrConstructPeer(peers peer.Peerstore, rpk ci.PubKey) (*peer.Peer, error) { - rid, err := IDFromPubKey(rpk) + rid, err := peer.IDFromPubKey(rpk) if err != nil { return nil, err } @@ -373,7 +381,8 @@ func getOrConstructPeer(peers peer.Peerstore, rpk ci.PubKey) (*peer.Peer, error) // this shouldn't ever happen, given we hashed, etc, but it could mean // expected code (or protocol) invariants violated. if !npeer.PubKey.Equals(rpk) { - return nil, fmt.Errorf("WARNING: PubKey mismatch: %v", npeer) + log.Error("WARNING: PubKey mismatch: %v", npeer) + panic("secure channel pubkey mismatch") } return npeer, nil } diff --git a/crypto/spipe/pipe.go b/crypto/spipe/pipe.go index 7f9ccc30f62..b1c56f1c17b 100644 --- a/crypto/spipe/pipe.go +++ b/crypto/spipe/pipe.go @@ -34,34 +34,29 @@ type params struct { // NewSecurePipe constructs a pipe with channels of a given buffer size. func NewSecurePipe(ctx context.Context, bufsize int, local *peer.Peer, - peers peer.Peerstore) (*SecurePipe, error) { + peers peer.Peerstore, insecure Duplex) (*SecurePipe, error) { + + ctx, cancel := context.WithCancel(ctx) sp := &SecurePipe{ Duplex: Duplex{ In: make(chan []byte, bufsize), Out: make(chan []byte, bufsize), }, - local: local, - peers: peers, - } - return sp, nil -} + local: local, + peers: peers, + insecure: insecure, -// Wrap creates a secure connection on top of an insecure duplex channel. -func (s *SecurePipe) Wrap(ctx context.Context, insecure Duplex) error { - if s.ctx != nil { - return errors.New("Pipe in use") + ctx: ctx, + cancel: cancel, } - s.insecure = insecure - s.ctx, s.cancel = context.WithCancel(ctx) - - if err := s.handshake(); err != nil { - s.cancel() - return err + if err := sp.handshake(); err != nil { + sp.Close() + return nil, err } - return nil + return sp, nil } // LocalPeer retrieves the local peer. @@ -76,11 +71,12 @@ func (s *SecurePipe) RemotePeer() *peer.Peer { // Close closes the secure pipe func (s *SecurePipe) Close() error { - if s.cancel == nil { - return errors.New("pipe already closed") + select { + case <-s.ctx.Done(): + return errors.New("already closed") + default: } s.cancel() - s.cancel = nil return nil } diff --git a/daemon/daemon_client.go b/daemon/daemon_client.go index 8db8615358f..4ed1be73cb1 100644 --- a/daemon/daemon_client.go +++ b/daemon/daemon_client.go @@ -47,26 +47,43 @@ func getDaemonAddr(confdir string) (string, error) { // over network RPC API. The address of the daemon is retrieved from the config // directory, where live daemons write their addresses to special files. func SendCommand(command *Command, confdir string) error { - //check if daemon is running - log.Info("Checking if daemon is running...") + server := os.Getenv("IPFS_ADDRESS_RPC") + + if server == "" { + //check if daemon is running + log.Info("Checking if daemon is running...") + if !serverIsRunning(confdir) { + return ErrDaemonNotRunning + } + + log.Info("Daemon is running!") + + var err error + server, err = getDaemonAddr(confdir) + if err != nil { + return err + } + } + + return serverComm(server, command) +} + +func serverIsRunning(confdir string) bool { var err error confdir, err = u.TildeExpansion(confdir) if err != nil { - return err + log.Error("Tilde Expansion Failed: %s", err) + return false } lk, err := daemonLock(confdir) if err == nil { lk.Close() - return ErrDaemonNotRunning - } - - log.Info("Daemon is running! [reason = %s]", err) - - server, err := getDaemonAddr(confdir) - if err != nil { - return err + return false } + return true +} +func serverComm(server string, command *Command) error { log.Info("Daemon address: %s", server) maddr, err := ma.NewMultiaddr(server) if err != nil { diff --git a/daemon/daemon_test.go b/daemon/daemon_test.go index ad65bfe2651..7fba742699f 100644 --- a/daemon/daemon_test.go +++ b/daemon/daemon_test.go @@ -9,7 +9,7 @@ import ( config "github.com/jbenet/go-ipfs/config" core "github.com/jbenet/go-ipfs/core" ci "github.com/jbenet/go-ipfs/crypto" - spipe "github.com/jbenet/go-ipfs/crypto/spipe" + peer "github.com/jbenet/go-ipfs/peer" ) func TestInitializeDaemonListener(t *testing.T) { @@ -23,7 +23,7 @@ func TestInitializeDaemonListener(t *testing.T) { t.Fatal(err) } - ident, _ := spipe.IDFromPubKey(pub) + ident, _ := peer.IDFromPubKey(pub) privKey := base64.StdEncoding.EncodeToString(prbytes) pID := ident.Pretty() diff --git a/diagnostics/diag.go b/diagnostics/diag.go index 8a6c636b6a5..f347c79ed6c 100644 --- a/diagnostics/diag.go +++ b/diagnostics/diag.go @@ -1,4 +1,4 @@ -package diagnostic +package diagnostics import ( "bytes" @@ -48,15 +48,17 @@ type connDiagInfo struct { ID string } -type diagInfo struct { +type DiagInfo struct { ID string Connections []connDiagInfo Keys []string LifeSpan time.Duration + BwIn uint64 + BwOut uint64 CodeVersion string } -func (di *diagInfo) Marshal() []byte { +func (di *DiagInfo) Marshal() []byte { b, err := json.Marshal(di) if err != nil { panic(err) @@ -69,12 +71,13 @@ func (d *Diagnostics) getPeers() []*peer.Peer { return d.network.GetPeerList() } -func (d *Diagnostics) getDiagInfo() *diagInfo { - di := new(diagInfo) +func (d *Diagnostics) getDiagInfo() *DiagInfo { + di := new(DiagInfo) di.CodeVersion = "github.com/jbenet/go-ipfs" di.ID = d.self.ID.Pretty() di.LifeSpan = time.Since(d.birth) di.Keys = nil // Currently no way to query datastore + di.BwIn, di.BwOut = d.network.GetBandwidthTotals() for _, p := range d.getPeers() { di.Connections = append(di.Connections, connDiagInfo{p.GetLatency(), p.ID.Pretty()}) @@ -88,7 +91,7 @@ func newID() string { return string(id) } -func (d *Diagnostics) GetDiagnostic(timeout time.Duration) ([]*diagInfo, error) { +func (d *Diagnostics) GetDiagnostic(timeout time.Duration) ([]*DiagInfo, error) { log.Debug("Getting diagnostic.") ctx, _ := context.WithTimeout(context.TODO(), timeout) @@ -102,7 +105,7 @@ func (d *Diagnostics) GetDiagnostic(timeout time.Duration) ([]*diagInfo, error) peers := d.getPeers() log.Debug("Sending diagnostic request to %d peers.", len(peers)) - var out []*diagInfo + var out []*DiagInfo di := d.getDiagInfo() out = append(out, di) @@ -134,15 +137,15 @@ func (d *Diagnostics) GetDiagnostic(timeout time.Duration) ([]*diagInfo, error) return out, nil } -func AppendDiagnostics(data []byte, cur []*diagInfo) []*diagInfo { +func AppendDiagnostics(data []byte, cur []*DiagInfo) []*DiagInfo { buf := bytes.NewBuffer(data) dec := json.NewDecoder(buf) for { - di := new(diagInfo) + di := new(DiagInfo) err := dec.Decode(di) if err != nil { if err != io.EOF { - log.Error("error decoding diagInfo: %v", err) + log.Error("error decoding DiagInfo: %v", err) } break } @@ -216,6 +219,7 @@ func (d *Diagnostics) handleDiagnostic(p *peer.Peer, pmes *Message) (*Message, e sendcount := 0 for _, p := range d.getPeers() { log.Debug("Sending diagnostic request to peer: %s", p) + sendcount++ go func(p *peer.Peer) { out, err := d.getDiagnosticFromPeer(ctx, p, pmes) if err != nil { diff --git a/diagnostics/message.pb.go b/diagnostics/message.pb.go index a3ef994efbb..30f2b58dd15 100644 --- a/diagnostics/message.pb.go +++ b/diagnostics/message.pb.go @@ -3,7 +3,7 @@ // DO NOT EDIT! /* -Package diagnostic is a generated protocol buffer package. +Package diagnostics is a generated protocol buffer package. It is generated from these files: message.proto @@ -11,7 +11,7 @@ It is generated from these files: It has these top-level messages: Message */ -package diagnostic +package diagnostics import proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto" import math "math" diff --git a/diagnostics/message.proto b/diagnostics/message.proto index 349afba257d..ca1e367f277 100644 --- a/diagnostics/message.proto +++ b/diagnostics/message.proto @@ -1,4 +1,4 @@ -package diagnostic; +package diagnostics; message Message { required string DiagID = 1; diff --git a/diagnostics/vis.go b/diagnostics/vis.go new file mode 100644 index 00000000000..def418e5c79 --- /dev/null +++ b/diagnostics/vis.go @@ -0,0 +1,56 @@ +package diagnostics + +import "encoding/json" + +type node struct { + Name string `json:"name"` + Value uint64 `json:"value"` +} + +type link struct { + Source int `json:"source"` + Target int `json:"target"` + Value int `json:"value"` +} + +func GetGraphJson(dinfo []*DiagInfo) []byte { + out := make(map[string]interface{}) + names := make(map[string]int) + var nodes []*node + for _, di := range dinfo { + names[di.ID] = len(nodes) + val := di.BwIn + di.BwOut + nodes = append(nodes, &node{Name: di.ID, Value: val}) + } + + var links []*link + linkexists := make([][]bool, len(nodes)) + for i, _ := range linkexists { + linkexists[i] = make([]bool, len(nodes)) + } + + for _, di := range dinfo { + myid := names[di.ID] + for _, con := range di.Connections { + thisid := names[con.ID] + if !linkexists[thisid][myid] { + links = append(links, &link{ + Source: myid, + Target: thisid, + Value: 3, + }) + linkexists[myid][thisid] = true + } + } + } + + out["nodes"] = nodes + out["links"] = links + + b, err := json.Marshal(out) + if err != nil { + panic(err) + } + + return b +} diff --git a/exchange/bitswap/bitswap.go b/exchange/bitswap/bitswap.go index 7eb8870aa50..b93b1a9b85e 100644 --- a/exchange/bitswap/bitswap.go +++ b/exchange/bitswap/bitswap.go @@ -24,13 +24,12 @@ func NetMessageSession(parent context.Context, p *peer.Peer, net inet.Network, srv inet.Service, directory bsnet.Routing, d ds.Datastore, nice bool) exchange.Interface { - networkAdapter := bsnet.NetMessageAdapter(srv, nil) + networkAdapter := bsnet.NetMessageAdapter(srv, net, nil) bs := &bitswap{ blockstore: blockstore.NewBlockstore(d), notifications: notifications.New(), strategy: strategy.New(nice), routing: directory, - network: net, sender: networkAdapter, wantlist: u.NewKeySet(), } @@ -42,9 +41,6 @@ func NetMessageSession(parent context.Context, p *peer.Peer, // bitswap instances implement the bitswap protocol. type bitswap struct { - // network maintains connections to the outside world. - network inet.Network - // sender delivers messages on behalf of the session sender bsnet.Adapter @@ -85,11 +81,20 @@ func (bs *bitswap) Block(parent context.Context, k u.Key) (*blocks.Block, error) message.AppendWanted(wanted) } message.AppendWanted(k) - for iiiii := range peersToQuery { - log.Debug("bitswap got peersToQuery: %s", iiiii) + for peerToQuery := range peersToQuery { + log.Debug("bitswap got peersToQuery: %s", peerToQuery) go func(p *peer.Peer) { + + log.Debug("bitswap dialing peer: %s", p) + err := bs.sender.DialPeer(p) + if err != nil { + log.Error("Error sender.DialPeer(%s)", p) + return + } + response, err := bs.sender.SendRequest(ctx, p, message) if err != nil { + log.Error("Error sender.SendRequest(%s)", p) return } // FIXME ensure accounting is handled correctly when @@ -101,7 +106,7 @@ func (bs *bitswap) Block(parent context.Context, k u.Key) (*blocks.Block, error) return } bs.ReceiveMessage(ctx, p, response) - }(iiiii) + }(peerToQuery) } }() diff --git a/exchange/bitswap/network/interface.go b/exchange/bitswap/network/interface.go index 8985ecefc30..03d7d341561 100644 --- a/exchange/bitswap/network/interface.go +++ b/exchange/bitswap/network/interface.go @@ -11,6 +11,9 @@ import ( // Adapter provides network connectivity for BitSwap sessions type Adapter interface { + // DialPeer ensures there is a connection to peer. + DialPeer(*peer.Peer) error + // SendMessage sends a BitSwap message to a peer. SendMessage( context.Context, diff --git a/exchange/bitswap/network/net_message_adapter.go b/exchange/bitswap/network/net_message_adapter.go index a95e566ccaa..52f42807688 100644 --- a/exchange/bitswap/network/net_message_adapter.go +++ b/exchange/bitswap/network/net_message_adapter.go @@ -10,9 +10,10 @@ import ( ) // NetMessageAdapter wraps a NetMessage network service -func NetMessageAdapter(s inet.Service, r Receiver) Adapter { +func NetMessageAdapter(s inet.Service, n inet.Network, r Receiver) Adapter { adapter := impl{ nms: s, + net: n, receiver: r, } s.SetHandler(&adapter) @@ -22,6 +23,7 @@ func NetMessageAdapter(s inet.Service, r Receiver) Adapter { // implements an Adapter that integrates with a NetMessage network service type impl struct { nms inet.Service + net inet.Network // inbound messages from the network are forwarded to the receiver receiver Receiver @@ -58,6 +60,10 @@ func (adapter *impl) HandleMessage( return outgoing } +func (adapter *impl) DialPeer(p *peer.Peer) error { + return adapter.net.DialPeer(p) +} + func (adapter *impl) SendMessage( ctx context.Context, p *peer.Peer, diff --git a/exchange/bitswap/testnet/network.go b/exchange/bitswap/testnet/network.go index 4d5f8c35ea4..c3081337df4 100644 --- a/exchange/bitswap/testnet/network.go +++ b/exchange/bitswap/testnet/network.go @@ -3,6 +3,7 @@ package bitswap import ( "bytes" "errors" + "fmt" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" bsmsg "github.com/jbenet/go-ipfs/exchange/bitswap/message" @@ -14,6 +15,8 @@ import ( type Network interface { Adapter(*peer.Peer) bsnet.Adapter + HasPeer(*peer.Peer) bool + SendMessage( ctx context.Context, from *peer.Peer, @@ -49,6 +52,11 @@ func (n *network) Adapter(p *peer.Peer) bsnet.Adapter { return client } +func (n *network) HasPeer(p *peer.Peer) bool { + _, found := n.clients[p.Key()] + return found +} + // TODO should this be completely asynchronous? // TODO what does the network layer do with errors received from services? func (n *network) SendMessage( @@ -155,6 +163,14 @@ func (nc *networkClient) SendRequest( return nc.network.SendRequest(ctx, nc.local, to, message) } +func (nc *networkClient) DialPeer(p *peer.Peer) error { + // no need to do anything because dialing isn't a thing in this test net. + if !nc.network.HasPeer(p) { + return fmt.Errorf("Peer not in network: %s", p) + } + return nil +} + func (nc *networkClient) SetDelegate(r bsnet.Receiver) { nc.Receiver = r } diff --git a/msgproto/msgproto.go b/msgproto/msgproto.go deleted file mode 100644 index bdd9f1ed51d..00000000000 --- a/msgproto/msgproto.go +++ /dev/null @@ -1 +0,0 @@ -package msgproto diff --git a/net/conn/conn.go b/net/conn/conn.go index dcf6c923116..00d4a91e9d9 100644 --- a/net/conn/conn.go +++ b/net/conn/conn.go @@ -2,99 +2,152 @@ package conn import ( "fmt" + "time" + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" msgio "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-msgio" ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" manet "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr/net" - spipe "github.com/jbenet/go-ipfs/crypto/spipe" peer "github.com/jbenet/go-ipfs/peer" u "github.com/jbenet/go-ipfs/util" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" ) var log = u.Logger("conn") -// ChanBuffer is the size of the buffer in the Conn Chan -const ChanBuffer = 10 +const ( + // ChanBuffer is the size of the buffer in the Conn Chan + ChanBuffer = 10 -// 1 MB -const MaxMessageSize = 1 << 20 + // MaxMessageSize is the size of the largest single message + MaxMessageSize = 1 << 20 // 1 MB -// Conn represents a connection to another Peer (IPFS Node). -type Conn struct { - Peer *peer.Peer - Addr ma.Multiaddr - Conn manet.Conn + // HandshakeTimeout for when nodes first connect + HandshakeTimeout = time.Second * 5 +) + +// msgioPipe is a pipe using msgio channels. +type msgioPipe struct { + outgoing *msgio.Chan + incoming *msgio.Chan +} + +func newMsgioPipe(size int) *msgioPipe { + return &msgioPipe{ + outgoing: msgio.NewChan(10), + incoming: msgio.NewChan(10), + } +} - Closed chan bool - Outgoing *msgio.Chan - Incoming *msgio.Chan - Secure *spipe.SecurePipe +// singleConn represents a single connection to another Peer (IPFS Node). +type singleConn struct { + local *peer.Peer + remote *peer.Peer + maconn manet.Conn + msgio *msgioPipe + + ctxc.ContextCloser } -// Map maps Keys (Peer.IDs) to Connections. -type Map map[u.Key]*Conn +// newConn constructs a new connection +func newSingleConn(ctx context.Context, local, remote *peer.Peer, + maconn manet.Conn) (Conn, error) { -// NewConn constructs a new connection -func NewConn(peer *peer.Peer, addr ma.Multiaddr, mconn manet.Conn) (*Conn, error) { - conn := &Conn{ - Peer: peer, - Addr: addr, - Conn: mconn, + conn := &singleConn{ + local: local, + remote: remote, + maconn: maconn, + msgio: newMsgioPipe(10), } - if err := conn.newChans(); err != nil { - return nil, err + conn.ContextCloser = ctxc.NewContextCloser(ctx, conn.close) + + log.Info("newSingleConn: %v to %v", local, remote) + + // setup the various io goroutines + go func() { + conn.Children().Add(1) + conn.msgio.outgoing.WriteTo(maconn) + conn.Children().Done() + }() + go func() { + conn.Children().Add(1) + conn.msgio.incoming.ReadFrom(maconn, MaxMessageSize) + conn.Children().Done() + }() + + // version handshake + ctxT, _ := context.WithTimeout(ctx, HandshakeTimeout) + if err := VersionHandshake(ctxT, conn); err != nil { + conn.Close() + return nil, fmt.Errorf("Version handshake: %s", err) } return conn, nil } -// Dial connects to a particular peer, over a given network -// Example: Dial("udp", peer) -func Dial(network string, peer *peer.Peer) (*Conn, error) { - addr := peer.NetAddress(network) - if addr == nil { - return nil, fmt.Errorf("No address for network %s", network) - } +// close is the internal close function, called by ContextCloser.Close +func (c *singleConn) close() error { + log.Debug("%s closing Conn with %s", c.local, c.remote) - nconn, err := manet.Dial(addr) - if err != nil { - return nil, err - } + // close underlying connection + err := c.maconn.Close() + c.msgio.outgoing.Close() + return err +} - return NewConn(peer, addr, nconn) +// ID is an identifier unique to this connection. +func (c *singleConn) ID() string { + return ID(c) } -// Construct new channels for given Conn. -func (c *Conn) newChans() error { - if c.Outgoing != nil || c.Incoming != nil { - return fmt.Errorf("Conn already initialized") - } +func (c *singleConn) String() string { + return String(c, "singleConn") +} - c.Outgoing = msgio.NewChan(10) - c.Incoming = msgio.NewChan(10) - c.Closed = make(chan bool, 1) +// LocalMultiaddr is the Multiaddr on this side +func (c *singleConn) LocalMultiaddr() ma.Multiaddr { + return c.maconn.LocalMultiaddr() +} - go c.Outgoing.WriteTo(c.Conn) - go c.Incoming.ReadFrom(c.Conn, MaxMessageSize) +// RemoteMultiaddr is the Multiaddr on the remote side +func (c *singleConn) RemoteMultiaddr() ma.Multiaddr { + return c.maconn.RemoteMultiaddr() +} - return nil +// LocalPeer is the Peer on this side +func (c *singleConn) LocalPeer() *peer.Peer { + return c.local } -// Close closes the connection, and associated channels. -func (c *Conn) Close() error { - log.Debug("Closing Conn with %v", c.Peer) - if c.Conn == nil { - return fmt.Errorf("Already closed") // already closed - } +// RemotePeer is the Peer on the remote side +func (c *singleConn) RemotePeer() *peer.Peer { + return c.remote +} - // closing net connection - err := c.Conn.Close() - c.Conn = nil - // closing channels - c.Incoming.Close() - c.Outgoing.Close() - c.Closed <- true - return err +// In returns a readable message channel +func (c *singleConn) In() <-chan []byte { + return c.msgio.incoming.MsgChan +} + +// Out returns a writable message channel +func (c *singleConn) Out() chan<- []byte { + return c.msgio.outgoing.MsgChan +} + +// ID returns the ID of a given Conn. +func ID(c Conn) string { + l := fmt.Sprintf("%s/%s", c.LocalMultiaddr(), c.LocalPeer().ID) + r := fmt.Sprintf("%s/%s", c.RemoteMultiaddr(), c.RemotePeer().ID) + lh := u.Hash([]byte(l)) + rh := u.Hash([]byte(r)) + ch := u.XOR(lh, rh) + return u.Key(ch).Pretty() +} + +// String returns the user-friendly String representation of a conn +func String(c Conn, typ string) string { + return fmt.Sprintf("%s (%s) <-- %s --> (%s) %s", + c.LocalPeer(), c.LocalMultiaddr(), typ, c.RemoteMultiaddr(), c.RemotePeer()) } diff --git a/net/conn/conn_test.go b/net/conn/conn_test.go index 95d5833dfa2..803b517a721 100644 --- a/net/conn/conn_test.go +++ b/net/conn/conn_test.go @@ -1,95 +1,141 @@ package conn import ( + "bytes" + "fmt" + "os" + "runtime" + "strconv" + "sync" "testing" + "time" - peer "github.com/jbenet/go-ipfs/peer" - - ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" - manet "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr/net" - mh "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multihash" + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" ) -func setupPeer(id string, addr string) (*peer.Peer, error) { - tcp, err := ma.NewMultiaddr(addr) - if err != nil { - return nil, err - } +func TestClose(t *testing.T) { + // t.Skip("Skipping in favor of another test") - mh, err := mh.FromHexString(id) - if err != nil { - return nil, err + ctx, cancel := context.WithCancel(context.Background()) + c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/5534", "/ip4/127.0.0.1/tcp/5545") + + select { + case <-c1.Closed(): + t.Fatal("done before close") + case <-c2.Closed(): + t.Fatal("done before close") + default: } - p := &peer.Peer{ID: peer.ID(mh)} - p.AddAddress(tcp) - return p, nil -} + c1.Close() -func echoListen(listener manet.Listener) { - for { - c, err := listener.Accept() - if err == nil { - // fmt.Println("accepeted") - go echo(c) - } + select { + case <-c1.Closed(): + default: + t.Fatal("not done after cancel") } -} -func echo(c manet.Conn) { - for { - data := make([]byte, 1024) - i, err := c.Read(data) - if err != nil { - // fmt.Printf("error %v\n", err) - return - } - _, err = c.Write(data[:i]) - if err != nil { - // fmt.Printf("error %v\n", err) - return - } - // fmt.Println("echoing", data[:i]) + c2.Close() + + select { + case <-c2.Closed(): + default: + t.Fatal("not done after cancel") } + + cancel() // close the listener :P } -func TestDial(t *testing.T) { +func TestCancel(t *testing.T) { + // t.Skip("Skipping in favor of another test") - maddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/1234") - if err != nil { - t.Fatal("failure to parse multiaddr") + ctx, cancel := context.WithCancel(context.Background()) + c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/5534", "/ip4/127.0.0.1/tcp/5545") + + select { + case <-c1.Closed(): + t.Fatal("done before close") + case <-c2.Closed(): + t.Fatal("done before close") + default: } - listener, err := manet.Listen(maddr) - if err != nil { - t.Fatal("error setting up listener", err) + + c1.Close() + c2.Close() + cancel() // listener + + // wait to ensure other goroutines run and close things. + <-time.After(time.Microsecond * 10) + // test that cancel called Close. + + select { + case <-c1.Closed(): + default: + t.Fatal("not done after cancel") } - go echoListen(listener) - p, err := setupPeer("11140beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33", "/ip4/127.0.0.1/tcp/1234") - if err != nil { - t.Fatal("error setting up peer", err) + select { + case <-c2.Closed(): + default: + t.Fatal("not done after cancel") } - c, err := Dial("tcp", p) - if err != nil { - t.Fatal("error dialing peer", err) +} + +func TestCloseLeak(t *testing.T) { + // t.Skip("Skipping in favor of another test") + + if os.Getenv("TRAVIS") == "true" { + t.Skip("this doesn't work well on travis") } - // fmt.Println("sending") - c.Outgoing.MsgChan <- []byte("beep") - c.Outgoing.MsgChan <- []byte("boop") - out := <-c.Incoming.MsgChan - // fmt.Println("recving", string(out)) - if string(out) != "beep" { - t.Error("unexpected conn output") + var wg sync.WaitGroup + + runPair := func(p1, p2, num int) { + a1 := strconv.Itoa(p1) + a2 := strconv.Itoa(p2) + ctx, cancel := context.WithCancel(context.Background()) + c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/"+a1, "/ip4/127.0.0.1/tcp/"+a2) + + for i := 0; i < num; i++ { + b1 := []byte("beep") + c1.Out() <- b1 + b2 := <-c2.In() + if !bytes.Equal(b1, b2) { + panic("bytes not equal") + } + + b2 = []byte("boop") + c2.Out() <- b2 + b1 = <-c1.In() + if !bytes.Equal(b1, b2) { + panic("bytes not equal") + } + + <-time.After(time.Microsecond * 5) + } + + c1.Close() + c2.Close() + cancel() // close the listener + wg.Done() } - out = <-c.Incoming.MsgChan - if string(out) != "boop" { - t.Error("unexpected conn output") + var cons = 20 + var msgs = 100 + fmt.Printf("Running %d connections * %d msgs.\n", cons, msgs) + for i := 0; i < cons; i++ { + wg.Add(1) + go runPair(2000+i, 2001+i, msgs) } - // fmt.Println("closing") - c.Close() - listener.Close() + fmt.Printf("Waiting...\n") + wg.Wait() + // done! + + <-time.After(time.Millisecond * 150) + if runtime.NumGoroutine() > 20 { + // panic("uncomment me to debug") + t.Fatal("leaking goroutines:", runtime.NumGoroutine()) + } } diff --git a/net/conn/dial.go b/net/conn/dial.go new file mode 100644 index 00000000000..7bf85b9138f --- /dev/null +++ b/net/conn/dial.go @@ -0,0 +1,46 @@ +package conn + +import ( + "fmt" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" + + manet "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr/net" + + peer "github.com/jbenet/go-ipfs/peer" +) + +// Dial connects to a particular peer, over a given network +// Example: d.Dial(ctx, "udp", peer) +func (d *Dialer) Dial(ctx context.Context, network string, remote *peer.Peer) (Conn, error) { + laddr := d.LocalPeer.NetAddress(network) + if laddr == nil { + return nil, fmt.Errorf("No local address for network %s", network) + } + + raddr := remote.NetAddress(network) + if raddr == nil { + return nil, fmt.Errorf("No remote address for network %s", network) + } + + // TODO: try to get reusing addr/ports to work. + // madialer := manet.Dialer{LocalAddr: laddr} + madialer := manet.Dialer{} + + log.Info("%s dialing %s %s", d.LocalPeer, remote, raddr) + maconn, err := madialer.Dial(raddr) + if err != nil { + return nil, err + } + + if err := d.Peerstore.Put(remote); err != nil { + log.Error("Error putting peer into peerstore: %s", remote) + } + + c, err := newSingleConn(ctx, d.LocalPeer, remote, maconn) + if err != nil { + return nil, err + } + + return newSecureConn(ctx, c, d.Peerstore) +} diff --git a/net/conn/dial_test.go b/net/conn/dial_test.go new file mode 100644 index 00000000000..e89f310402f --- /dev/null +++ b/net/conn/dial_test.go @@ -0,0 +1,153 @@ +package conn + +import ( + "testing" + + ci "github.com/jbenet/go-ipfs/crypto" + peer "github.com/jbenet/go-ipfs/peer" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" + ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" +) + +func setupPeer(addr string) (*peer.Peer, error) { + tcp, err := ma.NewMultiaddr(addr) + if err != nil { + return nil, err + } + + sk, pk, err := ci.GenerateKeyPair(ci.RSA, 512) + if err != nil { + return nil, err + } + + id, err := peer.IDFromPubKey(pk) + if err != nil { + return nil, err + } + + p := &peer.Peer{ID: id} + p.PrivKey = sk + p.PubKey = pk + p.AddAddress(tcp) + return p, nil +} + +func echoListen(ctx context.Context, listener Listener) { + for { + select { + case <-ctx.Done(): + return + case c := <-listener.Accept(): + go echo(ctx, c) + } + } +} + +func echo(ctx context.Context, c Conn) { + for { + select { + case <-ctx.Done(): + return + case m := <-c.In(): + c.Out() <- m + } + } +} + +func setupConn(t *testing.T, ctx context.Context, a1, a2 string) (a, b Conn) { + + p1, err := setupPeer(a1) + if err != nil { + t.Fatal("error setting up peer", err) + } + + p2, err := setupPeer(a2) + if err != nil { + t.Fatal("error setting up peer", err) + } + + laddr := p1.NetAddress("tcp") + if laddr == nil { + t.Fatal("Listen address is nil.") + } + + l1, err := Listen(ctx, laddr, p1, peer.NewPeerstore()) + if err != nil { + t.Fatal(err) + } + + d2 := &Dialer{ + Peerstore: peer.NewPeerstore(), + LocalPeer: p2, + } + + c2, err := d2.Dial(ctx, "tcp", p1) + if err != nil { + t.Fatal("error dialing peer", err) + } + + c1 := <-l1.Accept() + + return c1, c2 +} + +func TestDialer(t *testing.T) { + // t.Skip("Skipping in favor of another test") + + p1, err := setupPeer("/ip4/127.0.0.1/tcp/4234") + if err != nil { + t.Fatal("error setting up peer", err) + } + + p2, err := setupPeer("/ip4/127.0.0.1/tcp/4235") + if err != nil { + t.Fatal("error setting up peer", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + + laddr := p1.NetAddress("tcp") + if laddr == nil { + t.Fatal("Listen address is nil.") + } + + l, err := Listen(ctx, laddr, p1, peer.NewPeerstore()) + if err != nil { + t.Fatal(err) + } + + go echoListen(ctx, l) + + d := &Dialer{ + Peerstore: peer.NewPeerstore(), + LocalPeer: p2, + } + + c, err := d.Dial(ctx, "tcp", p1) + if err != nil { + t.Fatal("error dialing peer", err) + } + + // fmt.Println("sending") + c.Out() <- []byte("beep") + c.Out() <- []byte("boop") + + out := <-c.In() + // fmt.Println("recving", string(out)) + data := string(out) + if data != "beep" { + t.Error("unexpected conn output", data) + } + + out = <-c.In() + data = string(out) + if string(out) != "boop" { + t.Error("unexpected conn output", data) + } + + // fmt.Println("closing") + c.Close() + l.Close() + cancel() +} diff --git a/net/conn/handshake.go b/net/conn/handshake.go new file mode 100644 index 00000000000..633c8d5f7dc --- /dev/null +++ b/net/conn/handshake.go @@ -0,0 +1,58 @@ +package conn + +import ( + "errors" + "fmt" + + handshake "github.com/jbenet/go-ipfs/net/handshake" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" + proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto" +) + +// VersionHandshake exchanges local and remote versions and compares them +// closes remote and returns an error in case of major difference +func VersionHandshake(ctx context.Context, c Conn) error { + rpeer := c.RemotePeer() + lpeer := c.LocalPeer() + + var remoteH, localH *handshake.Handshake1 + localH = handshake.CurrentHandshake() + + myVerBytes, err := proto.Marshal(localH) + if err != nil { + return err + } + + c.Out() <- myVerBytes + log.Debug("Sent my version (%s) to %s", localH, rpeer) + + select { + case <-ctx.Done(): + return ctx.Err() + + case <-c.Closing(): + return errors.New("remote closed connection during version exchange") + + case data, ok := <-c.In(): + if !ok { + return fmt.Errorf("error retrieving from conn: %v", rpeer) + } + + remoteH = new(handshake.Handshake1) + err = proto.Unmarshal(data, remoteH) + if err != nil { + return fmt.Errorf("could not decode remote version: %q", err) + } + + log.Debug("Received remote version (%s) from %s", remoteH, rpeer) + } + + if err := handshake.Compatible(localH, remoteH); err != nil { + log.Info("%s (%s) incompatible version with %s (%s)", lpeer, localH, rpeer, remoteH) + return err + } + + log.Debug("%s version handshake compatible %s", lpeer, rpeer) + return nil +} diff --git a/net/conn/interface.go b/net/conn/interface.go new file mode 100644 index 00000000000..5cfd8336db3 --- /dev/null +++ b/net/conn/interface.go @@ -0,0 +1,78 @@ +package conn + +import ( + peer "github.com/jbenet/go-ipfs/peer" + u "github.com/jbenet/go-ipfs/util" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" + + ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" +) + +// Map maps Keys (Peer.IDs) to Connections. +type Map map[u.Key]Conn + +// Conn is a generic message-based Peer-to-Peer connection. +type Conn interface { + // implement ContextCloser too! + ctxc.ContextCloser + + // ID is an identifier unique to this connection. + ID() string + + // LocalMultiaddr is the Multiaddr on this side + LocalMultiaddr() ma.Multiaddr + + // LocalPeer is the Peer on this side + LocalPeer() *peer.Peer + + // RemoteMultiaddr is the Multiaddr on the remote side + RemoteMultiaddr() ma.Multiaddr + + // RemotePeer is the Peer on the remote side + RemotePeer() *peer.Peer + + // In returns a readable message channel + In() <-chan []byte + + // Out returns a writable message channel + Out() chan<- []byte + + // Close ends the connection + // Close() error -- already in ContextCloser +} + +// Dialer is an object that can open connections. We could have a "convenience" +// Dial function as before, but it would have many arguments, as dialing is +// no longer simple (need a peerstore, a local peer, a context, a network, etc) +type Dialer struct { + + // LocalPeer is the identity of the local Peer. + LocalPeer *peer.Peer + + // Peerstore is the set of peers we know about locally. The Dialer needs it + // because when an incoming connection is identified, we should reuse the + // same peer objects (otherwise things get inconsistent). + Peerstore peer.Peerstore +} + +// Listener is an object that can accept connections. It matches net.Listener +type Listener interface { + + // Accept waits for and returns the next connection to the listener. + Accept() <-chan Conn + + // Multiaddr is the identity of the local Peer. + Multiaddr() ma.Multiaddr + + // LocalPeer is the identity of the local Peer. + LocalPeer() *peer.Peer + + // Peerstore is the set of peers we know about locally. The Listener needs it + // because when an incoming connection is identified, we should reuse the + // same peer objects (otherwise things get inconsistent). + Peerstore() peer.Peerstore + + // Close closes the listener. + // Any blocked Accept operations will be unblocked and return errors. + Close() error +} diff --git a/net/conn/listen.go b/net/conn/listen.go new file mode 100644 index 00000000000..20cfbb4fbea --- /dev/null +++ b/net/conn/listen.go @@ -0,0 +1,149 @@ +package conn + +import ( + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" + ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" + manet "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr/net" + + peer "github.com/jbenet/go-ipfs/peer" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" +) + +// listener is an object that can accept connections. It implements Listener +type listener struct { + manet.Listener + + // chansize is the size of the internal channels for concurrency + chansize int + + // channel of incoming conections + conns chan Conn + + // Local multiaddr to listen on + maddr ma.Multiaddr + + // LocalPeer is the identity of the local Peer. + local *peer.Peer + + // Peerstore is the set of peers we know about locally + peers peer.Peerstore + + // Context for children Conn + ctx context.Context + + // embedded ContextCloser + ctxc.ContextCloser +} + +// disambiguate +func (l *listener) Close() error { + return l.ContextCloser.Close() +} + +// close called by ContextCloser.Close +func (l *listener) close() error { + log.Info("listener closing: %s %s", l.local, l.maddr) + return l.Listener.Close() +} + +func (l *listener) listen() { + l.Children().Add(1) + defer l.Children().Done() + + // handle at most chansize concurrent handshakes + sem := make(chan struct{}, l.chansize) + + // handle is a goroutine work function that handles the handshake. + // it's here only so that accepting new connections can happen quickly. + handle := func(maconn manet.Conn) { + defer func() { <-sem }() // release + + c, err := newSingleConn(l.ctx, l.local, nil, maconn) + if err != nil { + log.Error("Error accepting connection: %v", err) + return + } + + sc, err := newSecureConn(l.ctx, c, l.peers) + if err != nil { + log.Error("Error securing connection: %v", err) + return + } + + l.conns <- sc + } + + for { + maconn, err := l.Listener.Accept() + if err != nil { + + // if closing, we should exit. + select { + case <-l.Closing(): + return // done. + default: + } + + log.Error("Failed to accept connection: %v", err) + continue + } + + sem <- struct{}{} // acquire + go handle(maconn) + } +} + +// Accept waits for and returns the next connection to the listener. +// Note that unfortunately this +func (l *listener) Accept() <-chan Conn { + return l.conns +} + +// Multiaddr is the identity of the local Peer. +func (l *listener) Multiaddr() ma.Multiaddr { + return l.maddr +} + +// LocalPeer is the identity of the local Peer. +func (l *listener) LocalPeer() *peer.Peer { + return l.local +} + +// Peerstore is the set of peers we know about locally. The Listener needs it +// because when an incoming connection is identified, we should reuse the +// same peer objects (otherwise things get inconsistent). +func (l *listener) Peerstore() peer.Peerstore { + return l.peers +} + +// Listen listens on the particular multiaddr, with given peer and peerstore. +func Listen(ctx context.Context, addr ma.Multiaddr, local *peer.Peer, peers peer.Peerstore) (Listener, error) { + + ml, err := manet.Listen(addr) + if err != nil { + return nil, err + } + + // todo make this a variable + chansize := 10 + + l := &listener{ + Listener: ml, + maddr: addr, + peers: peers, + local: local, + conns: make(chan Conn, chansize), + chansize: chansize, + ctx: ctx, + } + + // need a separate context to use for the context closer. + // This is because the parent context will be given to all connections too, + // and if we close the listener, the connections shouldn't share the fate. + ctx2, _ := context.WithCancel(ctx) + l.ContextCloser = ctxc.NewContextCloser(ctx2, l.close) + + go l.listen() + + return l, nil +} diff --git a/net/conn/multiconn.go b/net/conn/multiconn.go new file mode 100644 index 00000000000..24b4cc99478 --- /dev/null +++ b/net/conn/multiconn.go @@ -0,0 +1,289 @@ +package conn + +import ( + "sync" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" + ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" + + peer "github.com/jbenet/go-ipfs/peer" + u "github.com/jbenet/go-ipfs/util" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" +) + +// MultiConnMap is for shorthand +type MultiConnMap map[u.Key]*MultiConn + +// Duplex is a simple duplex channel +type Duplex struct { + In chan []byte + Out chan []byte +} + +// MultiConn represents a single connection to another Peer (IPFS Node). +type MultiConn struct { + + // connections, mapped by a string, which uniquely identifies the connection. + // this string is: /addr1/peer1/addr2/peer2 (peers ordered lexicographically) + conns map[string]Conn + + local *peer.Peer + remote *peer.Peer + + // fan-in/fan-out + duplex Duplex + + // for adding/removing connections concurrently + sync.RWMutex + ctxc.ContextCloser +} + +// NewMultiConn constructs a new connection +func NewMultiConn(ctx context.Context, local, remote *peer.Peer, conns []Conn) (*MultiConn, error) { + + c := &MultiConn{ + local: local, + remote: remote, + conns: map[string]Conn{}, + duplex: Duplex{ + In: make(chan []byte, 10), + Out: make(chan []byte, 10), + }, + } + + // must happen before Adds / fanOut + c.ContextCloser = ctxc.NewContextCloser(ctx, c.close) + + if conns != nil && len(conns) > 0 { + c.Add(conns...) + } + go c.fanOut() + return c, nil +} + +// Add adds given Conn instances to multiconn. +func (c *MultiConn) Add(conns ...Conn) { + c.Lock() + defer c.Unlock() + + for _, c2 := range conns { + log.Info("MultiConn: adding %s", c2) + if c.LocalPeer() != c2.LocalPeer() || c.RemotePeer() != c2.RemotePeer() { + log.Error("%s", c2) + c.Unlock() // ok to unlock (to log). panicing. + log.Error("%s", c) + c.Lock() // gotta relock to avoid lock panic from deferring. + panic("connection addresses mismatch") + } + + c.conns[c2.ID()] = c2 + go c.fanInSingle(c2) + log.Info("MultiConn: added %s", c2) + } +} + +// Remove removes given Conn instances from multiconn. +func (c *MultiConn) Remove(conns ...Conn) { + + // first remove them to avoid sending any more messages through it. + { + c.Lock() + for _, c1 := range conns { + c2, found := c.conns[c1.ID()] + if !found { + panic("Conn not in MultiConn") + } + if c1 != c2 { + panic("different Conn objects for same id.") + } + + delete(c.conns, c2.ID()) + } + c.Unlock() + } + + // close all in parallel, but wait for all to be done closing. + CloseConns(conns...) +} + +// CloseConns closes multiple connections in parallel, and waits for all +// to finish closing. +func CloseConns(conns ...Conn) { + var wg sync.WaitGroup + for _, child := range conns { + + select { + case <-child.Closed(): // if already closed, continue + continue + default: + } + + wg.Add(1) + go func(child Conn) { + child.Close() + wg.Done() + }(child) + } + wg.Wait() +} + +// fanOut is the multiplexor out -- it sends outgoing messages over the +// underlying single connections. +func (c *MultiConn) fanOut() { + c.Children().Add(1) + defer c.Children().Done() + + i := 0 + for { + select { + case <-c.Closing(): + return + + // send data out through our "best connection" + case m, more := <-c.duplex.Out: + if !more { + log.Info("%s out channel closed", c) + return + } + sc := c.BestConn() + if sc == nil { + // maybe this should be a logged error, not a panic. + panic("sending out multiconn without any live connection") + } + + i++ + log.Info("%s sending (%d)", sc, i) + sc.Out() <- m + } + } +} + +// fanInSingle is a multiplexor in -- it receives incoming messages over the +// underlying single connections. +func (c *MultiConn) fanInSingle(child Conn) { + c.Children().Add(1) + child.Children().Add(1) // yep, on the child too. + + // cleanup all data associated with this child Connection. + defer func() { + log.Info("closing: %s", child) + + // in case it still is in the map, remove it. + c.Lock() + delete(c.conns, child.ID()) + connLen := len(c.conns) + c.Unlock() + + c.Children().Done() + child.Children().Done() + + if connLen == 0 { + c.Close() // close self if all underlying children are gone? + } + }() + + i := 0 + for { + select { + case <-c.Closing(): // multiconn closing + return + + case <-child.Closing(): // child closing + return + + case m, more := <-child.In(): // receiving data + if !more { + log.Info("%s in channel closed", child) + return // closed + } + i++ + log.Info("%s received (%d)", child, i) + c.duplex.In <- m + } + } +} + +// close is the internal close function, called by ContextCloser.Close +func (c *MultiConn) close() error { + log.Debug("%s closing Conn with %s", c.local, c.remote) + + // get connections + c.RLock() + conns := make([]Conn, 0, len(c.conns)) + for _, c := range c.conns { + conns = append(conns, c) + } + c.RUnlock() + + // close underlying connections + CloseConns(conns...) + return nil +} + +// BestConn is the best connection in this MultiConn +func (c *MultiConn) BestConn() Conn { + c.RLock() + defer c.RUnlock() + + var id1 string + var c1 Conn + for id2, c2 := range c.conns { + if id1 == "" || id1 < id2 { + id1 = id2 + c1 = c2 + } + } + return c1 +} + +// ID is an identifier unique to this connection. +// In MultiConn, this is all the children IDs XORed together. +func (c *MultiConn) ID() string { + c.RLock() + defer c.RUnlock() + + ids := []byte(nil) + for i := range c.conns { + if ids == nil { + ids = []byte(i) + } else { + ids = u.XOR(ids, []byte(i)) + } + } + + return string(ids) +} + +func (c *MultiConn) String() string { + return String(c, "MultiConn") +} + +// LocalMultiaddr is the Multiaddr on this side +func (c *MultiConn) LocalMultiaddr() ma.Multiaddr { + return c.BestConn().LocalMultiaddr() +} + +// RemoteMultiaddr is the Multiaddr on the remote side +func (c *MultiConn) RemoteMultiaddr() ma.Multiaddr { + return c.BestConn().RemoteMultiaddr() +} + +// LocalPeer is the Peer on this side +func (c *MultiConn) LocalPeer() *peer.Peer { + return c.local +} + +// RemotePeer is the Peer on the remote side +func (c *MultiConn) RemotePeer() *peer.Peer { + return c.remote +} + +// In returns a readable message channel +func (c *MultiConn) In() <-chan []byte { + return c.duplex.In +} + +// Out returns a writable message channel +func (c *MultiConn) Out() chan<- []byte { + return c.duplex.Out +} diff --git a/net/conn/multiconn_test.go b/net/conn/multiconn_test.go new file mode 100644 index 00000000000..bb8404a135b --- /dev/null +++ b/net/conn/multiconn_test.go @@ -0,0 +1,324 @@ +package conn + +import ( + "fmt" + "sync" + "testing" + "time" + + peer "github.com/jbenet/go-ipfs/peer" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" + ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" +) + +func tcpAddr(t *testing.T, port int) ma.Multiaddr { + tcp, err := ma.NewMultiaddr(tcpAddrString(port)) + if err != nil { + t.Fatal(err) + } + return tcp +} + +func tcpAddrString(port int) string { + return fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", port) +} + +type msg struct { + sent bool + received bool + payload string +} + +func (m *msg) Sent(t *testing.T) { + if m.sent { + t.Fatal("sent msg at incorrect state:", m) + } + m.sent = true +} + +func (m *msg) Received(t *testing.T) { + if m.received { + t.Fatal("received msg at incorrect state:", m) + } + m.received = true +} + +type msgMap struct { + sent int + recv int + msgs map[string]*msg +} + +func (mm *msgMap) Sent(t *testing.T, payload string) { + mm.msgs[payload].Sent(t) + mm.sent++ +} + +func (mm *msgMap) Received(t *testing.T, payload string) { + mm.msgs[payload].Received(t) + mm.recv++ +} + +func (mm *msgMap) CheckDone(t *testing.T) { + if mm.sent != len(mm.msgs) { + t.Fatal("failed to send all msgs", mm.sent, len(mm.msgs)) + } + + if mm.sent != len(mm.msgs) { + t.Fatal("failed to send all msgs", mm.sent, len(mm.msgs)) + } +} + +func genMessages(num int, tag string) *msgMap { + msgs := &msgMap{msgs: map[string]*msg{}} + for i := 0; i < num; i++ { + s := fmt.Sprintf("Message #%d -- %s", i, tag) + msgs.msgs[s] = &msg{payload: s} + } + return msgs +} + +func setupMultiConns(t *testing.T, ctx context.Context) (a, b *MultiConn) { + + log.Info("Setting up peers") + p1, err := setupPeer(tcpAddrString(11000)) + if err != nil { + t.Fatal("error setting up peer", err) + } + + p2, err := setupPeer(tcpAddrString(12000)) + if err != nil { + t.Fatal("error setting up peer", err) + } + + // peerstores + p1ps := peer.NewPeerstore() + p2ps := peer.NewPeerstore() + + // listeners + listen := func(addr ma.Multiaddr, p *peer.Peer, ps peer.Peerstore) Listener { + l, err := Listen(ctx, addr, p, ps) + if err != nil { + t.Fatal(err) + } + return l + } + + log.Info("Setting up listeners") + p1l := listen(p1.Addresses[0], p1, p1ps) + p2l := listen(p2.Addresses[0], p2, p2ps) + + // dialers + p1d := &Dialer{Peerstore: p1ps, LocalPeer: p1} + p2d := &Dialer{Peerstore: p2ps, LocalPeer: p2} + + dial := func(d *Dialer, dst *peer.Peer) <-chan Conn { + cc := make(chan Conn) + go func() { + c, err := d.Dial(ctx, "tcp", dst) + if err != nil { + t.Fatal("error dialing peer", err) + } + cc <- c + }() + return cc + } + + // connect simultaneously + log.Info("Connecting...") + p1dc := dial(p1d, p2) + p2dc := dial(p2d, p1) + + c12a := <-p1l.Accept() + c12b := <-p1dc + c21a := <-p2l.Accept() + c21b := <-p2dc + + log.Info("Ok, making multiconns") + c1, err := NewMultiConn(ctx, p1, p2, []Conn{c12a, c12b}) + if err != nil { + t.Fatal(err) + } + + c2, err := NewMultiConn(ctx, p2, p1, []Conn{c21a, c21b}) + if err != nil { + t.Fatal(err) + } + + p1l.Close() + p2l.Close() + + log.Info("did you make multiconns?") + return c1, c2 +} + +func TestMulticonnSend(t *testing.T) { + // t.Skip("fooo") + + log.Info("TestMulticonnSend") + ctx := context.Background() + ctxC, cancel := context.WithCancel(ctx) + + c1, c2 := setupMultiConns(t, ctx) + + log.Info("gen msgs") + num := 100 + msgsFrom1 := genMessages(num, "from p1 to p2") + msgsFrom2 := genMessages(num, "from p2 to p1") + + var wg sync.WaitGroup + + send := func(c *MultiConn, msgs *msgMap) { + defer wg.Done() + + for _, m := range msgs.msgs { + log.Info("send: %s", m.payload) + c.Out() <- []byte(m.payload) + msgs.Sent(t, m.payload) + <-time.After(time.Microsecond * 10) + } + } + + recv := func(ctx context.Context, c *MultiConn, msgs *msgMap) { + defer wg.Done() + + for { + select { + case payload := <-c.In(): + msgs.Received(t, string(payload)) + log.Info("recv: %s", payload) + if msgs.recv == len(msgs.msgs) { + return + } + + case <-ctx.Done(): + return + + } + } + + } + + log.Info("msg send + recv") + + wg.Add(4) + go send(c1, msgsFrom1) + go send(c2, msgsFrom2) + go recv(ctxC, c1, msgsFrom2) + go recv(ctxC, c2, msgsFrom1) + wg.Wait() + cancel() + c1.Close() + c2.Close() + + msgsFrom1.CheckDone(t) + msgsFrom2.CheckDone(t) + <-time.After(100 * time.Millisecond) +} + +func TestMulticonnSendUnderlying(t *testing.T) { + // t.Skip("fooo") + + log.Info("TestMulticonnSendUnderlying") + ctx := context.Background() + ctxC, cancel := context.WithCancel(ctx) + + c1, c2 := setupMultiConns(t, ctx) + + log.Info("gen msgs") + num := 100 + msgsFrom1 := genMessages(num, "from p1 to p2") + msgsFrom2 := genMessages(num, "from p2 to p1") + + var wg sync.WaitGroup + + send := func(c *MultiConn, msgs *msgMap) { + defer wg.Done() + + conns := make([]Conn, 0, len(c.conns)) + for _, c1 := range c.conns { + conns = append(conns, c1) + } + + i := 0 + for _, m := range msgs.msgs { + log.Info("send: %s", m.payload) + switch i % 3 { + case 0: + conns[0].Out() <- []byte(m.payload) + case 1: + conns[1].Out() <- []byte(m.payload) + case 2: + c.Out() <- []byte(m.payload) + } + msgs.Sent(t, m.payload) + <-time.After(time.Microsecond * 10) + i++ + } + } + + recv := func(ctx context.Context, c *MultiConn, msgs *msgMap) { + defer wg.Done() + + for { + select { + case payload := <-c.In(): + msgs.Received(t, string(payload)) + log.Info("recv: %s", payload) + if msgs.recv == len(msgs.msgs) { + return + } + + case <-ctx.Done(): + return + + } + } + + } + + log.Info("msg send + recv") + + wg.Add(4) + go send(c1, msgsFrom1) + go send(c2, msgsFrom2) + go recv(ctxC, c1, msgsFrom2) + go recv(ctxC, c2, msgsFrom1) + wg.Wait() + cancel() + c1.Close() + c2.Close() + + msgsFrom1.CheckDone(t) + msgsFrom2.CheckDone(t) +} + +func TestMulticonnClose(t *testing.T) { + // t.Skip("fooo") + + log.Info("TestMulticonnSendUnderlying") + ctx := context.Background() + c1, c2 := setupMultiConns(t, ctx) + + for _, c := range c1.conns { + c.Close() + } + + for _, c := range c2.conns { + c.Close() + } + + timeout := time.After(100 * time.Millisecond) + select { + case <-c1.Closed(): + case <-timeout: + t.Fatal("timeout") + } + + select { + case <-c2.Closed(): + case <-timeout: + t.Fatal("timeout") + } +} diff --git a/net/conn/secure_conn.go b/net/conn/secure_conn.go new file mode 100644 index 00000000000..dfccbaf2e0e --- /dev/null +++ b/net/conn/secure_conn.go @@ -0,0 +1,134 @@ +package conn + +import ( + "errors" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" + ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" + + spipe "github.com/jbenet/go-ipfs/crypto/spipe" + peer "github.com/jbenet/go-ipfs/peer" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" +) + +// secureConn wraps another Conn object with an encrypted channel. +type secureConn struct { + + // the wrapped conn + insecure Conn + + // secure pipe, wrapping insecure + secure *spipe.SecurePipe + + ctxc.ContextCloser +} + +// newConn constructs a new connection +func newSecureConn(ctx context.Context, insecure Conn, peers peer.Peerstore) (Conn, error) { + + conn := &secureConn{ + insecure: insecure, + } + conn.ContextCloser = ctxc.NewContextCloser(ctx, conn.close) + + log.Debug("newSecureConn: %v to %v", insecure.LocalPeer(), insecure.RemotePeer()) + // perform secure handshake before returning this connection. + if err := conn.secureHandshake(peers); err != nil { + conn.Close() + return nil, err + } + log.Debug("newSecureConn: %v to %v handshake success!", insecure.LocalPeer(), insecure.RemotePeer()) + + return conn, nil +} + +// secureHandshake performs the spipe secure handshake. +func (c *secureConn) secureHandshake(peers peer.Peerstore) error { + if c.secure != nil { + return errors.New("Conn is already secured or being secured.") + } + + // ok to panic here if this type assertion fails. Interface hack. + // when we support wrapping other Conns, we'll need to change + // spipe to do something else. + insecureSC := c.insecure.(*singleConn) + + // setup a Duplex pipe for spipe + insecureD := spipe.Duplex{ + In: insecureSC.msgio.incoming.MsgChan, + Out: insecureSC.msgio.outgoing.MsgChan, + } + + // spipe performs the secure handshake, which takes multiple RTT + sp, err := spipe.NewSecurePipe(c.Context(), 10, c.LocalPeer(), peers, insecureD) + if err != nil { + return err + } + + // assign it into the conn object + c.secure = sp + + // if we do not know RemotePeer, get it from secure chan (who identifies it) + if insecureSC.remote == nil { + insecureSC.remote = c.secure.RemotePeer() + + } else if insecureSC.remote != c.secure.RemotePeer() { + // this panic is here because this would be an insidious programmer error + // that we need to ensure we catch. + // update: this actually might happen under normal operation-- should + // perhaps return an error. TBD. + + log.Error("secureConn peer mismatch. %v != %v", insecureSC.remote, c.secure.RemotePeer()) + panic("secureConn peer mismatch. consructed incorrectly?") + } + + return nil +} + +// close is called by ContextCloser +func (c *secureConn) close() error { + err := c.insecure.Close() + if c.secure != nil { // may never have gotten here. + err = c.secure.Close() + } + return err +} + +// ID is an identifier unique to this connection. +func (c *secureConn) ID() string { + return ID(c) +} + +func (c *secureConn) String() string { + return String(c, "secureConn") +} + +// LocalMultiaddr is the Multiaddr on this side +func (c *secureConn) LocalMultiaddr() ma.Multiaddr { + return c.insecure.LocalMultiaddr() +} + +// RemoteMultiaddr is the Multiaddr on the remote side +func (c *secureConn) RemoteMultiaddr() ma.Multiaddr { + return c.insecure.RemoteMultiaddr() +} + +// LocalPeer is the Peer on this side +func (c *secureConn) LocalPeer() *peer.Peer { + return c.insecure.LocalPeer() +} + +// RemotePeer is the Peer on the remote side +func (c *secureConn) RemotePeer() *peer.Peer { + return c.insecure.RemotePeer() +} + +// In returns a readable message channel +func (c *secureConn) In() <-chan []byte { + return c.secure.In +} + +// Out returns a writable message channel +func (c *secureConn) Out() chan<- []byte { + return c.secure.Out +} diff --git a/net/conn/secure_conn_test.go b/net/conn/secure_conn_test.go new file mode 100644 index 00000000000..f7567148009 --- /dev/null +++ b/net/conn/secure_conn_test.go @@ -0,0 +1,165 @@ +package conn + +import ( + "bytes" + "fmt" + "os" + "runtime" + "strconv" + "sync" + "testing" + "time" + + peer "github.com/jbenet/go-ipfs/peer" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +func setupSecureConn(t *testing.T, c Conn) Conn { + c, ok := c.(*secureConn) + if ok { + return c + } + + // shouldn't happen, because dial + listen already return secure conns. + s, err := newSecureConn(c.Context(), c, peer.NewPeerstore()) + if err != nil { + t.Fatal(err) + } + return s +} + +func TestSecureClose(t *testing.T) { + // t.Skip("Skipping in favor of another test") + + ctx, cancel := context.WithCancel(context.Background()) + c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/6634", "/ip4/127.0.0.1/tcp/6645") + + c1 = setupSecureConn(t, c1) + c2 = setupSecureConn(t, c2) + + select { + case <-c1.Closed(): + t.Fatal("done before close") + case <-c2.Closed(): + t.Fatal("done before close") + default: + } + + c1.Close() + + select { + case <-c1.Closed(): + default: + t.Fatal("not done after close") + } + + c2.Close() + + select { + case <-c2.Closed(): + default: + t.Fatal("not done after close") + } + + cancel() // close the listener :P +} + +func TestSecureCancel(t *testing.T) { + // t.Skip("Skipping in favor of another test") + + ctx, cancel := context.WithCancel(context.Background()) + c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/6634", "/ip4/127.0.0.1/tcp/6645") + + c1 = setupSecureConn(t, c1) + c2 = setupSecureConn(t, c2) + + select { + case <-c1.Closed(): + t.Fatal("done before close") + case <-c2.Closed(): + t.Fatal("done before close") + default: + } + + c1.Close() + c2.Close() + cancel() // listener + + // wait to ensure other goroutines run and close things. + <-time.After(time.Microsecond * 10) + // test that cancel called Close. + + select { + case <-c1.Closed(): + default: + t.Fatal("not done after cancel") + } + + select { + case <-c2.Closed(): + default: + t.Fatal("not done after cancel") + } + +} + +func TestSecureCloseLeak(t *testing.T) { + // t.Skip("Skipping in favor of another test") + if os.Getenv("TRAVIS") == "true" { + t.Skip("this doesn't work well on travis") + } + + var wg sync.WaitGroup + + runPair := func(p1, p2, num int) { + a1 := strconv.Itoa(p1) + a2 := strconv.Itoa(p2) + ctx, cancel := context.WithCancel(context.Background()) + c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/"+a1, "/ip4/127.0.0.1/tcp/"+a2) + + c1 = setupSecureConn(t, c1) + c2 = setupSecureConn(t, c2) + + for i := 0; i < num; i++ { + b1 := []byte("beep") + c1.Out() <- b1 + b2 := <-c2.In() + if !bytes.Equal(b1, b2) { + panic("bytes not equal") + } + + b2 = []byte("boop") + c2.Out() <- b2 + b1 = <-c1.In() + if !bytes.Equal(b1, b2) { + panic("bytes not equal") + } + + <-time.After(time.Microsecond * 5) + } + + c1.Close() + c2.Close() + cancel() // close the listener + wg.Done() + } + + var cons = 20 + var msgs = 100 + fmt.Printf("Running %d connections * %d msgs.\n", cons, msgs) + for i := 0; i < cons; i++ { + wg.Add(1) + go runPair(2000+i, 2001+i, msgs) + } + + fmt.Printf("Waiting...\n") + wg.Wait() + // done! + + <-time.After(time.Millisecond * 150) + if runtime.NumGoroutine() > 20 { + // panic("uncomment me to debug") + t.Fatal("leaking goroutines:", runtime.NumGoroutine()) + } +} diff --git a/net/interface.go b/net/interface.go index dee1460fc81..379d0196805 100644 --- a/net/interface.go +++ b/net/interface.go @@ -29,6 +29,10 @@ type Network interface { // GetPeerList returns the list of peers currently connected in this network. GetPeerList() []*peer.Peer + // GetBandwidthTotals returns the total number of bytes passed through + // the network since it was instantiated + GetBandwidthTotals() (uint64, uint64) + // SendMessage sends given Message out SendMessage(msg.NetMessage) error diff --git a/net/mux/mux.go b/net/mux/mux.go index 3138fe873f9..ab325ecd55d 100644 --- a/net/mux/mux.go +++ b/net/mux/mux.go @@ -36,6 +36,12 @@ type Muxer struct { ctx context.Context wg sync.WaitGroup + bwiLock sync.Mutex + bwIn uint64 + + bwoLock sync.Mutex + bwOut uint64 + *msg.Pipe } @@ -76,6 +82,17 @@ func (m *Muxer) Start(ctx context.Context) error { return nil } +func (m *Muxer) GetBandwidthTotals() (in uint64, out uint64) { + m.bwiLock.Lock() + in = m.bwIn + m.bwiLock.Unlock() + + m.bwoLock.Lock() + out = m.bwOut + m.bwoLock.Unlock() + return +} + // Stop stops muxer activity. func (m *Muxer) Stop() { if m.cancel == nil { @@ -125,6 +142,11 @@ func (m *Muxer) handleIncomingMessages() { // handleIncomingMessage routes message to the appropriate protocol. func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) { + m.bwiLock.Lock() + // TODO: compensate for overhead + m.bwIn += uint64(len(m1.Data())) + m.bwiLock.Unlock() + data, pid, err := unwrapData(m1.Data()) if err != nil { log.Error("muxer de-serializing error: %v", err) @@ -173,6 +195,11 @@ func (m *Muxer) handleOutgoingMessage(pid ProtocolID, m1 msg.NetMessage) { return } + m.bwoLock.Lock() + // TODO: compensate for overhead + m.bwOut += uint64(len(data)) + m.bwoLock.Unlock() + m2 := msg.New(m1.Peer(), data) select { case m.GetPipe().Outgoing <- m2: diff --git a/net/net.go b/net/net.go index b5864fe68a6..9ec7d2982d4 100644 --- a/net/net.go +++ b/net/net.go @@ -111,3 +111,7 @@ func (n *IpfsNetwork) Close() error { func (n *IpfsNetwork) GetPeerList() []*peer.Peer { return n.swarm.GetPeerList() } + +func (n *IpfsNetwork) GetBandwidthTotals() (in uint64, out uint64) { + return n.muxer.GetBandwidthTotals() +} diff --git a/net/swarm/conn.go b/net/swarm/conn.go index 5aa5a304e2d..891a191a694 100644 --- a/net/swarm/conn.go +++ b/net/swarm/conn.go @@ -4,14 +4,10 @@ import ( "errors" "fmt" - spipe "github.com/jbenet/go-ipfs/crypto/spipe" conn "github.com/jbenet/go-ipfs/net/conn" - handshake "github.com/jbenet/go-ipfs/net/handshake" msg "github.com/jbenet/go-ipfs/net/message" - proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto" ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" - manet "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr/net" ) // Open listeners for each network the swarm should listen on @@ -39,28 +35,35 @@ func (s *Swarm) listen() error { // Listen for new connections on the given multiaddr func (s *Swarm) connListen(maddr ma.Multiaddr) error { - list, err := manet.Listen(maddr) + + list, err := conn.Listen(s.Context(), maddr, s.local, s.peers) if err != nil { return err } + // make sure port can be reused. TOOD this doesn't work... + // if err := setSocketReuse(list); err != nil { + // return err + // } + // NOTE: this may require a lock around it later. currently, only run on setup s.listeners = append(s.listeners, list) // Accept and handle new connections on this listener until it errors + // this listener is a child. + s.Children().Add(1) go func() { + defer s.Children().Done() + for { - nconn, err := list.Accept() - if err != nil { - e := fmt.Errorf("Failed to accept connection: %s - %s", maddr, err) - s.errChan <- e + select { + case <-s.Closing(): + return - // if cancel is nil, we're closed. - if s.cancel == nil { - return - } - } else { - go s.handleIncomingConn(nconn) + case conn := <-list.Accept(): + // handler also a child. + s.Children().Add(1) + go s.handleIncomingConn(conn) } } }() @@ -69,202 +72,158 @@ func (s *Swarm) connListen(maddr ma.Multiaddr) error { } // Handle getting ID from this peer, handshake, and adding it into the map -func (s *Swarm) handleIncomingConn(nconn manet.Conn) { - - addr := nconn.RemoteMultiaddr() - - // Construct conn with nil peer for now, because we don't know its ID yet. - // connSetup will figure this out, and pull out / construct the peer. - c, err := conn.NewConn(nil, addr, nconn) - if err != nil { - s.errChan <- err - return - } +func (s *Swarm) handleIncomingConn(nconn conn.Conn) { + // this handler is a child. added by caller. + defer s.Children().Done() // Setup the new connection - err = s.connSetup(c) + _, err := s.connSetup(nconn) if err != nil && err != ErrAlreadyOpen { s.errChan <- err - c.Close() + nconn.Close() } } // connSetup adds the passed in connection to its peerMap and starts -// the fanIn routine for that connection -func (s *Swarm) connSetup(c *conn.Conn) error { +// the fanInSingle routine for that connection +func (s *Swarm) connSetup(c conn.Conn) (conn.Conn, error) { if c == nil { - return errors.New("Tried to start nil connection.") - } - - if c.Peer != nil { - log.Debug("Starting connection: %s", c.Peer) - } else { - log.Debug("Starting connection: [unknown peer]") + return nil, errors.New("Tried to start nil connection.") } - if err := s.connSecure(c); err != nil { - return fmt.Errorf("Conn securing error: %v", err) - } - - log.Debug("Secured connection: %s", c.Peer) + log.Debug("%s Started connection: %s", c.LocalPeer(), c.RemotePeer()) // add address of connection to Peer. Maybe it should happen in connSecure. - c.Peer.AddAddress(c.Addr) - - if err := s.connVersionExchange(c); err != nil { - return fmt.Errorf("Conn version exchange error: %v", err) - } + // NOT adding this address here, because the incoming address in TCP + // is an EPHEMERAL address, and not the address we want to keep around. + // addresses should be figured out through the DHT. + // c.Remote.AddAddress(c.Conn.RemoteMultiaddr()) // add to conns s.connsLock.Lock() - if _, ok := s.conns[c.Peer.Key()]; ok { - log.Debug("Conn already open!") - s.connsLock.Unlock() - return ErrAlreadyOpen - } - s.conns[c.Peer.Key()] = c - log.Debug("Added conn to map!") - s.connsLock.Unlock() - - // kick off reader goroutine - go s.fanIn(c) - return nil -} - -// connSecure setups a secure remote connection. -func (s *Swarm) connSecure(c *conn.Conn) error { - - sp, err := spipe.NewSecurePipe(s.ctx, 10, s.local, s.peers) - if err != nil { - return err - } - - err = sp.Wrap(s.ctx, spipe.Duplex{ - In: c.Incoming.MsgChan, - Out: c.Outgoing.MsgChan, - }) - if err != nil { - return err - } - - if c.Peer == nil { - c.Peer = sp.RemotePeer() - - } else if c.Peer != sp.RemotePeer() { - panic("peers not being constructed correctly.") - } - c.Secure = sp - return nil -} - -// connVersionExchange exchanges local and remote versions and compares them -// closes remote and returns an error in case of major difference -func (s *Swarm) connVersionExchange(remote *conn.Conn) error { - var remoteHandshake, localHandshake *handshake.Handshake1 - localHandshake = handshake.CurrentHandshake() - - myVerBytes, err := proto.Marshal(localHandshake) - if err != nil { - return err - } - - remote.Secure.Out <- myVerBytes - - log.Debug("Send my version(%s) [to = %s]", localHandshake, remote.Peer) - - select { - case <-s.ctx.Done(): - return s.ctx.Err() - - case <-remote.Closed: - return errors.New("remote closed connection during version exchange") - - case data, ok := <-remote.Secure.In: - if !ok { - return fmt.Errorf("Error retrieving from conn: %v", remote.Peer) - } - - remoteHandshake = new(handshake.Handshake1) - err = proto.Unmarshal(data, remoteHandshake) + mc, found := s.conns[c.RemotePeer().Key()] + if !found { + // multiconn doesn't exist, make a new one. + conns := []conn.Conn{c} + mc, err := conn.NewMultiConn(s.Context(), s.local, c.RemotePeer(), conns) if err != nil { - s.Close() - return fmt.Errorf("connSetup: could not decode remote version: %q", err) + log.Error("error creating multiconn: %s", err) + c.Close() + return nil, err } - log.Debug("Received remote version(%s) [from = %s]", remoteHandshake, remote.Peer) - } + s.conns[c.RemotePeer().Key()] = mc + s.connsLock.Unlock() - if err := handshake.Compatible(localHandshake, remoteHandshake); err != nil { - log.Info("%s (%s) incompatible version with %s (%s)", s.local, localHandshake, remote.Peer, remoteHandshake) - remote.Close() - return err + // kick off reader goroutine + go s.fanInSingle(mc) + log.Debug("added new multiconn: %s", mc) + } else { + s.connsLock.Unlock() // unlock before adding new conn + + mc.Add(c) + log.Debug("multiconn found: %s", mc) } - log.Debug("[peer: %s] Version compatible", remote.Peer) - return nil + log.Debug("multiconn added new conn %s", c) + return c, nil } // Handles the unwrapping + sending of messages to the right connection. func (s *Swarm) fanOut() { + s.Children().Add(1) + defer s.Children().Done() + + i := 0 for { select { - case <-s.ctx.Done(): + case <-s.Closing(): return // told to close. case msg, ok := <-s.Outgoing: if !ok { + log.Info("%s outgoing channel closed", s) return } s.connsLock.RLock() - conn, found := s.conns[msg.Peer().Key()] + c, found := s.conns[msg.Peer().Key()] s.connsLock.RUnlock() if !found { - e := fmt.Errorf("Sent msg to peer without open conn: %v", - msg.Peer) + e := fmt.Errorf("Sent msg to peer without open conn: %v", msg.Peer()) s.errChan <- e + log.Error("%s", e) continue } - // log.Debug("[peer: %s] Sent message [to = %s]", s.local, msg.Peer()) - + i++ + log.Debug("%s sent message to %s (%d)", s.local, msg.Peer(), i) // queue it in the connection's buffer - conn.Secure.Out <- msg.Data() + c.Out() <- msg.Data() } } } // Handles the receiving + wrapping of messages, per conn. // Consider using reflect.Select with one goroutine instead of n. -func (s *Swarm) fanIn(c *conn.Conn) { +func (s *Swarm) fanInSingle(c conn.Conn) { + s.Children().Add(1) + c.Children().Add(1) // child of Conn as well. + + // cleanup all data associated with this child Connection. + defer func() { + // remove it from the map. + s.connsLock.Lock() + delete(s.conns, c.RemotePeer().Key()) + s.connsLock.Unlock() + + s.Children().Done() + c.Children().Done() // child of Conn as well. + }() + + i := 0 for { select { - case <-s.ctx.Done(): - // close Conn. - c.Close() - goto out + case <-s.Closing(): // Swarm closing + return - case <-c.Closed: - goto out + case <-c.Closing(): // Conn closing + return - case data, ok := <-c.Secure.In: + case data, ok := <-c.In(): if !ok { - e := fmt.Errorf("Error retrieving from conn: %v", c.Peer) - s.errChan <- e - goto out + log.Info("%s in channel closed", c) + return // channel closed. } - - // log.Debug("[peer: %s] Received message [from = %s]", s.local, c.Peer) - - msg := msg.New(c.Peer, data) - s.Incoming <- msg + i++ + log.Debug("%s received message from %s (%d)", s.local, c.RemotePeer(), i) + s.Incoming <- msg.New(c.RemotePeer(), data) } } - -out: - s.connsLock.Lock() - delete(s.conns, c.Peer.Key()) - s.connsLock.Unlock() } + +// Commenting out because it's platform specific +// func setSocketReuse(l manet.Listener) error { +// nl := l.NetListener() +// +// // for now only TCP. TODO change this when more networks. +// file, err := nl.(*net.TCPListener).File() +// if err != nil { +// return err +// } +// +// fd := file.Fd() +// err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) +// if err != nil { +// return err +// } +// +// err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1) +// if err != nil { +// return err +// } +// +// return nil +// } diff --git a/net/swarm/simul_test.go b/net/swarm/simul_test.go new file mode 100644 index 00000000000..2cffd0d2c61 --- /dev/null +++ b/net/swarm/simul_test.go @@ -0,0 +1,76 @@ +package swarm + +import ( + "fmt" + "sync" + "testing" + + peer "github.com/jbenet/go-ipfs/peer" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +func TestSimultOpen(t *testing.T) { + // t.Skip("skipping for another test") + + addrs := []string{ + "/ip4/127.0.0.1/tcp/1244", + "/ip4/127.0.0.1/tcp/1245", + } + + ctx := context.Background() + swarms, _ := makeSwarms(ctx, t, addrs) + + // connect everyone + { + var wg sync.WaitGroup + connect := func(s *Swarm, dst *peer.Peer) { + // copy for other peer + cp := &peer.Peer{ID: dst.ID} + cp.AddAddress(dst.Addresses[0]) + + if _, err := s.Dial(cp); err != nil { + t.Fatal("error swarm dialing to peer", err) + } + wg.Done() + } + + log.Info("Connecting swarms simultaneously.") + wg.Add(2) + go connect(swarms[0], swarms[1].local) + go connect(swarms[1], swarms[0].local) + wg.Wait() + } + + for _, s := range swarms { + s.Close() + } +} + +func TestSimultOpenMany(t *testing.T) { + t.Skip("laggy") + + many := 500 + addrs := []string{} + for i := 2200; i < (2200 + many); i++ { + s := fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", i) + addrs = append(addrs, s) + } + + SubtestSwarm(t, addrs, 10) +} + +func TestSimultOpenFewStress(t *testing.T) { + // t.Skip("skipping for another test") + + num := 10 + // num := 100 + for i := 0; i < num; i++ { + addrs := []string{ + fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", 1900+i), + fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", 2900+i), + } + + SubtestSwarm(t, addrs, 10) + } +} diff --git a/net/swarm/swarm.go b/net/swarm/swarm.go index 057e4ad2609..157a9ff9238 100644 --- a/net/swarm/swarm.go +++ b/net/swarm/swarm.go @@ -9,10 +9,9 @@ import ( msg "github.com/jbenet/go-ipfs/net/message" peer "github.com/jbenet/go-ipfs/peer" u "github.com/jbenet/go-ipfs/util" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" - ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" - manet "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr/net" ) var log = u.Logger("swarm") @@ -58,48 +57,42 @@ type Swarm struct { errChan chan error // conns are the open connections the swarm is handling. - conns conn.Map + // these are MultiConns, which multiplex multiple separate underlying Conns. + conns conn.MultiConnMap connsLock sync.RWMutex // listeners for each network address - listeners []manet.Listener + listeners []conn.Listener - // cancel is an internal function used to stop the Swarm's processing. - cancel context.CancelFunc - ctx context.Context + // ContextCloser + ctxc.ContextCloser } // NewSwarm constructs a Swarm, with a Chan. func NewSwarm(ctx context.Context, local *peer.Peer, ps peer.Peerstore) (*Swarm, error) { s := &Swarm{ Pipe: msg.NewPipe(10), - conns: conn.Map{}, + conns: conn.MultiConnMap{}, local: local, peers: ps, errChan: make(chan error, 100), } - s.ctx, s.cancel = context.WithCancel(ctx) + // ContextCloser for proper child management. + s.ContextCloser = ctxc.NewContextCloser(ctx, s.close) + go s.fanOut() return s, s.listen() } -// Close stops a swarm. -func (s *Swarm) Close() error { - if s.cancel == nil { - return errors.New("Swarm already closed.") - } - - // issue cancel for the context - s.cancel() - - // set cancel to nil to prevent calling Close again, and signal to Listeners - s.cancel = nil - +// close stops a swarm. It's the underlying function called by ContextCloser +func (s *Swarm) close() error { // close listeners for _, list := range s.listeners { list.Close() } + // close connections + conn.CloseConns(s.Connections()...) return nil } @@ -111,7 +104,7 @@ func (s *Swarm) Close() error { // etc. to achive connection. // // For now, Dial uses only TCP. This will be extended. -func (s *Swarm) Dial(peer *peer.Peer) (*conn.Conn, error) { +func (s *Swarm) Dial(peer *peer.Peer) (conn.Conn, error) { if peer.ID.Equal(s.local.ID) { return nil, errors.New("Attempted connection to self!") } @@ -129,45 +122,27 @@ func (s *Swarm) Dial(peer *peer.Peer) (*conn.Conn, error) { } // open connection to peer - c, err = conn.Dial("tcp", peer) - if err != nil { - return nil, err - } - - if err := s.connSetup(c); err != nil { - c.Close() - return nil, err - } - - return c, nil -} - -// DialAddr is for connecting to a peer when you know their addr but not their ID. -// Should only be used when sure that not connected to peer in question -// TODO(jbenet) merge with Dial? need way to patch back. -func (s *Swarm) DialAddr(addr ma.Multiaddr) (*conn.Conn, error) { - if addr == nil { - return nil, errors.New("addr must be a non-nil Multiaddr") + d := &conn.Dialer{ + LocalPeer: s.local, + Peerstore: s.peers, } - npeer := new(peer.Peer) - npeer.AddAddress(addr) - - c, err := conn.Dial("tcp", npeer) + c, err = d.Dial(s.Context(), "tcp", peer) if err != nil { return nil, err } - if err := s.connSetup(c); err != nil { + c, err = s.connSetup(c) + if err != nil { c.Close() return nil, err } - return c, err + return c, nil } // GetConnection returns the connection in the swarm to given peer.ID -func (s *Swarm) GetConnection(pid peer.ID) *conn.Conn { +func (s *Swarm) GetConnection(pid peer.ID) conn.Conn { s.connsLock.RLock() c, found := s.conns[u.Key(pid)] s.connsLock.RUnlock() @@ -178,6 +153,19 @@ func (s *Swarm) GetConnection(pid peer.ID) *conn.Conn { return c } +// Connections returns a slice of all connections. +func (s *Swarm) Connections() []conn.Conn { + s.connsLock.RLock() + + conns := make([]conn.Conn, 0, len(s.conns)) + for _, c := range s.conns { + conns = append(conns, c) + } + + s.connsLock.RUnlock() + return conns +} + // CloseConnection removes a given peer from swarm + closes the connection func (s *Swarm) CloseConnection(p *peer.Peer) error { c := s.GetConnection(p.ID) @@ -201,11 +189,12 @@ func (s *Swarm) GetErrChan() chan error { return s.errChan } +// GetPeerList returns a copy of the set of peers swarm is connected to. func (s *Swarm) GetPeerList() []*peer.Peer { var out []*peer.Peer s.connsLock.RLock() for _, p := range s.conns { - out = append(out, p.Peer) + out = append(out, p.RemotePeer()) } s.connsLock.RUnlock() return out diff --git a/net/swarm/swarm_test.go b/net/swarm/swarm_test.go index 88de9198d2e..d920b6b87d1 100644 --- a/net/swarm/swarm_test.go +++ b/net/swarm/swarm_test.go @@ -1,147 +1,187 @@ package swarm import ( - "fmt" + "bytes" + "sync" "testing" + "time" + ci "github.com/jbenet/go-ipfs/crypto" msg "github.com/jbenet/go-ipfs/net/message" peer "github.com/jbenet/go-ipfs/peer" u "github.com/jbenet/go-ipfs/util" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" - msgio "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-msgio" ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" - manet "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr/net" - mh "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multihash" ) -func pingListen(t *testing.T, listener manet.Listener, peer *peer.Peer) { +func pong(ctx context.Context, swarm *Swarm) { + i := 0 for { - c, err := listener.Accept() - if err == nil { - go pong(t, c, peer) - } - } -} - -func pong(t *testing.T, c manet.Conn, peer *peer.Peer) { - mrw := msgio.NewReadWriter(c) - for { - data := make([]byte, 1024) - n, err := mrw.ReadMsg(data) - if err != nil { - fmt.Printf("error %v\n", err) - return - } - d := string(data[:n]) - if d != "ping" { - t.Errorf("error: didn't receive ping: '%v'\n", d) - return - } - err = mrw.WriteMsg([]byte("pong")) - if err != nil { - fmt.Printf("error %v\n", err) + select { + case <-ctx.Done(): return + case m1 := <-swarm.Incoming: + if bytes.Equal(m1.Data(), []byte("ping")) { + m2 := msg.New(m1.Peer(), []byte("pong")) + i++ + log.Debug("%s pong %s (%d)", swarm.local, m1.Peer(), i) + swarm.Outgoing <- m2 + } } } } -func setupPeer(id string, addr string) (*peer.Peer, error) { +func setupPeer(t *testing.T, addr string) *peer.Peer { tcp, err := ma.NewMultiaddr(addr) if err != nil { - return nil, err + t.Fatal(err) } - mh, err := mh.FromHexString(id) + sk, pk, err := ci.GenerateKeyPair(ci.RSA, 512) if err != nil { - return nil, err + t.Fatal(err) } - p := &peer.Peer{ID: peer.ID(mh)} + id, err := peer.IDFromPubKey(pk) + if err != nil { + t.Fatal(err) + } + + p := &peer.Peer{ID: id} + p.PrivKey = sk + p.PubKey = pk p.AddAddress(tcp) - return p, nil + return p } -func TestSwarm(t *testing.T) { - t.Skip("TODO FIXME nil pointer") +func makeSwarms(ctx context.Context, t *testing.T, addrs []string) ([]*Swarm, []*peer.Peer) { + swarms := []*Swarm{} - local, err := setupPeer("11140beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a30", - "/ip4/127.0.0.1/tcp/1234") - if err != nil { - t.Fatal("error setting up peer", err) + for _, addr := range addrs { + local := setupPeer(t, addr) + peerstore := peer.NewPeerstore() + swarm, err := NewSwarm(ctx, local, peerstore) + if err != nil { + t.Fatal(err) + } + swarms = append(swarms, swarm) } - peerstore := peer.NewPeerstore() - - swarm, err := NewSwarm(context.Background(), local, peerstore) - if err != nil { - t.Error(err) - } - var peers []*peer.Peer - var listeners []manet.Listener - peerNames := map[string]string{ - "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a31": "/ip4/127.0.0.1/tcp/2345", - "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a32": "/ip4/127.0.0.1/tcp/3456", - "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33": "/ip4/127.0.0.1/tcp/4567", - "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a34": "/ip4/127.0.0.1/tcp/5678", + peers := make([]*peer.Peer, len(swarms)) + for i, s := range swarms { + peers[i] = s.local } - for k, n := range peerNames { - peer, err := setupPeer(k, n) - if err != nil { - t.Fatal("error setting up peer", err) + return swarms, peers +} + +func SubtestSwarm(t *testing.T, addrs []string, MsgNum int) { + // t.Skip("skipping for another test") + + ctx := context.Background() + swarms, peers := makeSwarms(ctx, t, addrs) + + // connect everyone + { + var wg sync.WaitGroup + connect := func(s *Swarm, dst *peer.Peer) { + // copy for other peer + + cp, err := s.peers.Get(dst.ID) + if err != nil { + cp = &peer.Peer{ID: dst.ID} + } + cp.AddAddress(dst.Addresses[0]) + + log.Info("SWARM TEST: %s dialing %s", s.local, dst) + if _, err := s.Dial(cp); err != nil { + t.Fatal("error swarm dialing to peer", err) + } + log.Info("SWARM TEST: %s connected to %s", s.local, dst) + wg.Done() } - a := peer.NetAddress("tcp") - if a == nil { - t.Fatal("error setting up peer (addr is nil)", peer) + + log.Info("Connecting swarms simultaneously.") + for _, s := range swarms { + for _, p := range peers { + if p != s.local { // don't connect to self. + wg.Add(1) + connect(s, p) + } + } } - listener, err := manet.Listen(a) - if err != nil { - t.Fatal("error setting up listener", err) + wg.Wait() + } + + // ping/pong + for _, s1 := range swarms { + ctx, cancel := context.WithCancel(ctx) + + // setup all others to pong + for _, s2 := range swarms { + if s1 == s2 { + continue + } + + go pong(ctx, s2) } - go pingListen(t, listener, peer) - _, err = swarm.Dial(peer) + peers, err := s1.peers.All() if err != nil { - t.Fatal("error swarm dialing to peer", err) + t.Fatal(err) } - // ok done, add it. - peers = append(peers, peer) - listeners = append(listeners, listener) - } + for k := 0; k < MsgNum; k++ { + for _, p := range *peers { + log.Debug("%s ping %s (%d)", s1.local, p, k) + s1.Outgoing <- msg.New(p, []byte("ping")) + } + } + + got := map[u.Key]int{} + for k := 0; k < (MsgNum * len(*peers)); k++ { + log.Debug("%s waiting for pong (%d)", s1.local, k) + msg := <-s1.Incoming + if string(msg.Data()) != "pong" { + t.Error("unexpected conn output", msg.Data) + } - MsgNum := 1000 - for k := 0; k < MsgNum; k++ { - for _, p := range peers { - swarm.Outgoing <- msg.New(p, []byte("ping")) + n, _ := got[msg.Peer().Key()] + got[msg.Peer().Key()] = n + 1 } - } - got := map[u.Key]int{} + if len(*peers) != len(got) { + t.Error("got less messages than sent") + } - for k := 0; k < (MsgNum * len(peers)); k++ { - msg := <-swarm.Incoming - if string(msg.Data()) != "pong" { - t.Error("unexpected conn output", msg.Data) + for p, n := range got { + if n != MsgNum { + t.Error("peer did not get all msgs", p, n, "/", MsgNum) + } } - n, _ := got[msg.Peer().Key()] - got[msg.Peer().Key()] = n + 1 + cancel() + <-time.After(50 * time.Microsecond) } - if len(peers) != len(got) { - t.Error("got less messages than sent") + for _, s := range swarms { + s.Close() } +} - for p, n := range got { - if n != MsgNum { - t.Error("peer did not get all msgs", p, n, "/", MsgNum) - } +func TestSwarm(t *testing.T) { + // t.Skip("skipping for another test") + + addrs := []string{ + "/ip4/127.0.0.1/tcp/10234", + "/ip4/127.0.0.1/tcp/10235", + "/ip4/127.0.0.1/tcp/10236", + "/ip4/127.0.0.1/tcp/10237", + "/ip4/127.0.0.1/tcp/10238", } - swarm.Close() - for _, listener := range listeners { - listener.Close() - } + // msgs := 1000 + msgs := 100 + SubtestSwarm(t, addrs, msgs) } diff --git a/peer/peer.go b/peer/peer.go index 0961f340841..6039e1fac53 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -36,6 +36,16 @@ func DecodePrettyID(s string) ID { return b58.Decode(s) } +// IDFromPubKey retrieves a Public Key from the peer given by pk +func IDFromPubKey(pk ic.PubKey) (ID, error) { + b, err := pk.Bytes() + if err != nil { + return nil, err + } + hash := u.Hash(b) + return ID(hash), nil +} + // Map maps Key (string) : *Peer (slices are not comparable). type Map map[u.Key]*Peer @@ -55,7 +65,7 @@ type Peer struct { // String prints out the peer. func (p *Peer) String() string { - return "[Peer " + p.ID.String() + "]" + return "[Peer " + p.ID.String()[:12] + "]" } // Key returns the ID as a Key (string) for maps. diff --git a/routing/dht/Message.go b/routing/dht/Message.go index e4607f1de3e..526724287b9 100644 --- a/routing/dht/Message.go +++ b/routing/dht/Message.go @@ -1,7 +1,10 @@ package dht import ( + "errors" + "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto" + ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" peer "github.com/jbenet/go-ipfs/peer" ) @@ -35,6 +38,14 @@ func peersToPBPeers(peers []*peer.Peer) []*Message_Peer { return pbpeers } +// Address returns a multiaddr associated with the Message_Peer entry +func (m *Message_Peer) Address() (ma.Multiaddr, error) { + if m == nil { + return nil, errors.New("MessagePeer is nil") + } + return ma.NewMultiaddr(*m.Addr) +} + // GetClusterLevel gets and adjusts the cluster level on the message. // a +/- 1 adjustment is needed to distinguish a valid first level (1) and // default "no value" protobuf behavior (0) diff --git a/routing/dht/dht.go b/routing/dht/dht.go index c95e0751136..2d4c9852e87 100644 --- a/routing/dht/dht.go +++ b/routing/dht/dht.go @@ -23,6 +23,8 @@ import ( var log = u.Logger("dht") +const doPinging = true + // TODO. SEE https://github.com/jbenet/node-ipfs/blob/master/submodules/ipfs-dht/index.js // IpfsDHT is an implementation of Kademlia with Coral and S/Kademlia modifications. @@ -53,16 +55,19 @@ type IpfsDHT struct { //lock to make diagnostics work better diaglock sync.Mutex + + ctx context.Context } // NewDHT creates a new DHT object with the given peer as the 'local' host -func NewDHT(p *peer.Peer, ps peer.Peerstore, net inet.Network, sender inet.Sender, dstore ds.Datastore) *IpfsDHT { +func NewDHT(ctx context.Context, p *peer.Peer, ps peer.Peerstore, net inet.Network, sender inet.Sender, dstore ds.Datastore) *IpfsDHT { dht := new(IpfsDHT) dht.network = net dht.sender = sender dht.datastore = dstore dht.self = p dht.peerstore = ps + dht.ctx = ctx dht.providers = NewProviderManager(p.ID) @@ -71,12 +76,16 @@ func NewDHT(p *peer.Peer, ps peer.Peerstore, net inet.Network, sender inet.Sende dht.routingTables[1] = kb.NewRoutingTable(20, kb.ConvertPeerID(p.ID), time.Millisecond*1000) dht.routingTables[2] = kb.NewRoutingTable(20, kb.ConvertPeerID(p.ID), time.Hour) dht.birth = time.Now() + + if doPinging { + go dht.PingRoutine(time.Second * 10) + } return dht } // Connect to a new peer at the given address, ping and add to the routing table func (dht *IpfsDHT) Connect(ctx context.Context, npeer *peer.Peer) (*peer.Peer, error) { - log.Debug("Connect to new peer: %s\n", npeer) + log.Debug("Connect to new peer: %s", npeer) // TODO(jbenet,whyrusleeping) // @@ -109,13 +118,13 @@ func (dht *IpfsDHT) HandleMessage(ctx context.Context, mes msg.NetMessage) msg.N mData := mes.Data() if mData == nil { - // TODO handle/log err + log.Error("Message contained nil data.") return nil } mPeer := mes.Peer() if mPeer == nil { - // TODO handle/log err + log.Error("Message contained nil peer.") return nil } @@ -123,7 +132,7 @@ func (dht *IpfsDHT) HandleMessage(ctx context.Context, mes msg.NetMessage) msg.N pmes := new(Message) err := proto.Unmarshal(mData, pmes) if err != nil { - // TODO handle/log err + log.Error("Error unmarshaling data") return nil } @@ -137,26 +146,27 @@ func (dht *IpfsDHT) HandleMessage(ctx context.Context, mes msg.NetMessage) msg.N // get handler for this msg type. handler := dht.handlerForMsgType(pmes.GetType()) if handler == nil { - // TODO handle/log err + log.Error("got back nil handler from handlerForMsgType") return nil } // dispatch handler. rpmes, err := handler(mPeer, pmes) if err != nil { - // TODO handle/log err + log.Error("handle message error: %s", err) return nil } // if nil response, return it before serializing if rpmes == nil { + log.Warning("Got back nil response from request.") return nil } // serialize response msg rmes, err := msg.FromObject(mPeer, rpmes) if err != nil { - // TODO handle/log err + log.Error("serialze response error: %s", err) return nil } @@ -197,6 +207,7 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p *peer.Peer, pmes *Message return rpmes, nil } +// putValueToNetwork stores the given key/value pair at the peer 'p' func (dht *IpfsDHT) putValueToNetwork(ctx context.Context, p *peer.Peer, key string, value []byte) error { @@ -216,13 +227,17 @@ func (dht *IpfsDHT) putValueToNetwork(ctx context.Context, p *peer.Peer, func (dht *IpfsDHT) putProvider(ctx context.Context, p *peer.Peer, key string) error { pmes := newMessage(Message_ADD_PROVIDER, string(key), 0) + + // add self as the provider + pmes.ProviderPeers = peersToPBPeers([]*peer.Peer{dht.self}) + rpmes, err := dht.sendRequest(ctx, p, pmes) if err != nil { return err } log.Debug("%s putProvider: %s for %s", dht.self, p, key) - if *rpmes.Key != *pmes.Key { + if rpmes.GetKey() != pmes.GetKey() { return errors.New("provider not added correctly") } @@ -257,23 +272,11 @@ func (dht *IpfsDHT) getValueOrPeers(ctx context.Context, p *peer.Peer, // Perhaps we were given closer peers var peers []*peer.Peer for _, pb := range pmes.GetCloserPeers() { - if peer.ID(pb.GetId()).Equal(dht.self.ID) { - continue - } - - addr, err := ma.NewMultiaddr(pb.GetAddr()) + pr, err := dht.addPeer(pb) if err != nil { - log.Error("%v", err.Error()) + log.Error("%s", err) continue } - - // check if we already have this peer. - pr, _ := dht.peerstore.Get(peer.ID(pb.GetId())) - if pr == nil { - pr = &peer.Peer{ID: peer.ID(pb.GetId())} - dht.peerstore.Put(pr) - } - pr.AddAddress(addr) // idempotent peers = append(peers, pr) } @@ -286,6 +289,27 @@ func (dht *IpfsDHT) getValueOrPeers(ctx context.Context, p *peer.Peer, return nil, nil, u.ErrNotFound } +func (dht *IpfsDHT) addPeer(pb *Message_Peer) (*peer.Peer, error) { + if peer.ID(pb.GetId()).Equal(dht.self.ID) { + return nil, errors.New("cannot add self as peer") + } + + addr, err := ma.NewMultiaddr(pb.GetAddr()) + if err != nil { + return nil, err + } + + // check if we already have this peer. + pr, _ := dht.peerstore.Get(peer.ID(pb.GetId())) + if pr == nil { + pr = &peer.Peer{ID: peer.ID(pb.GetId())} + dht.peerstore.Put(pr) + } + pr.AddAddress(addr) // idempotent + + return pr, nil +} + // getValueSingle simply performs the get value RPC with the given parameters func (dht *IpfsDHT) getValueSingle(ctx context.Context, p *peer.Peer, key u.Key, level int) (*Message, error) { @@ -323,6 +347,7 @@ func (dht *IpfsDHT) getFromPeerList(ctx context.Context, key u.Key, return nil, u.ErrNotFound } +// getLocal attempts to retrieve the value from the datastore func (dht *IpfsDHT) getLocal(key u.Key) ([]byte, error) { dht.dslock.Lock() defer dht.dslock.Unlock() @@ -333,11 +358,12 @@ func (dht *IpfsDHT) getLocal(key u.Key) ([]byte, error) { byt, ok := v.([]byte) if !ok { - return byt, errors.New("value stored in datastore not []byte") + return nil, errors.New("value stored in datastore not []byte") } return byt, nil } +// putLocal stores the key value pair in the datastore func (dht *IpfsDHT) putLocal(key u.Key, value []byte) error { return dht.datastore.Put(key.DsKey(), value) } @@ -364,8 +390,8 @@ func (dht *IpfsDHT) Update(p *peer.Peer) { // after some deadline of inactivity. } -// Find looks for a peer with a given ID connected to this dht and returns the peer and the table it was found in. -func (dht *IpfsDHT) Find(id peer.ID) (*peer.Peer, *kb.RoutingTable) { +// FindLocal looks for a peer with a given ID connected to this dht and returns the peer and the table it was found in. +func (dht *IpfsDHT) FindLocal(id peer.ID) (*peer.Peer, *kb.RoutingTable) { for _, table := range dht.routingTables { p := table.Find(id) if p != nil { @@ -415,39 +441,44 @@ func (dht *IpfsDHT) addProviders(key u.Key, peers []*Message_Peer) []*peer.Peer return provArr } -// nearestPeerToQuery returns the routing tables closest peers. -func (dht *IpfsDHT) nearestPeerToQuery(pmes *Message) *peer.Peer { +// nearestPeersToQuery returns the routing tables closest peers. +func (dht *IpfsDHT) nearestPeersToQuery(pmes *Message, count int) []*peer.Peer { level := pmes.GetClusterLevel() cluster := dht.routingTables[level] key := u.Key(pmes.GetKey()) - closer := cluster.NearestPeer(kb.ConvertKey(key)) + closer := cluster.NearestPeers(kb.ConvertKey(key), count) return closer } -// betterPeerToQuery returns nearestPeerToQuery, but iff closer than self. -func (dht *IpfsDHT) betterPeerToQuery(pmes *Message) *peer.Peer { - closer := dht.nearestPeerToQuery(pmes) +// betterPeerToQuery returns nearestPeersToQuery, but iff closer than self. +func (dht *IpfsDHT) betterPeersToQuery(pmes *Message, count int) []*peer.Peer { + closer := dht.nearestPeersToQuery(pmes, count) // no node? nil if closer == nil { return nil } - // == to self? nil - if closer.ID.Equal(dht.self.ID) { - log.Error("Attempted to return self! this shouldnt happen...") - return nil + // == to self? thats bad + for _, p := range closer { + if p.ID.Equal(dht.self.ID) { + log.Error("Attempted to return self! this shouldnt happen...") + return nil + } } - // self is closer? nil - key := u.Key(pmes.GetKey()) - if kb.Closer(dht.self.ID, closer.ID, key) { - return nil + var filtered []*peer.Peer + for _, p := range closer { + // must all be closer than self + key := u.Key(pmes.GetKey()) + if !kb.Closer(dht.self.ID, p.ID, key) { + filtered = append(filtered, p) + } } - // ok seems like a closer node. - return closer + // ok seems like closer nodes + return filtered } func (dht *IpfsDHT) peerFromInfo(pbp *Message_Peer) (*peer.Peer, error) { @@ -461,14 +492,14 @@ func (dht *IpfsDHT) peerFromInfo(pbp *Message_Peer) (*peer.Peer, error) { p, _ := dht.peerstore.Get(id) if p == nil { - p, _ = dht.Find(id) + p, _ = dht.FindLocal(id) if p != nil { panic("somehow peer not getting into peerstore") } } if p == nil { - maddr, err := ma.NewMultiaddr(pbp.GetAddr()) + maddr, err := pbp.Address() if err != nil { return nil, err } @@ -477,6 +508,7 @@ func (dht *IpfsDHT) peerFromInfo(pbp *Message_Peer) (*peer.Peer, error) { p = &peer.Peer{ID: id} p.AddAddress(maddr) dht.peerstore.Put(p) + log.Info("dht found new peer: %s %s", p, maddr) } return p, nil } @@ -509,9 +541,33 @@ func (dht *IpfsDHT) loadProvidableKeys() error { return nil } +func (dht *IpfsDHT) PingRoutine(t time.Duration) { + tick := time.Tick(t) + for { + select { + case <-tick: + id := make([]byte, 16) + rand.Read(id) + peers := dht.routingTables[0].NearestPeers(kb.ConvertKey(u.Key(id)), 5) + for _, p := range peers { + ctx, _ := context.WithTimeout(dht.ctx, time.Second*5) + err := dht.Ping(ctx, p) + if err != nil { + log.Error("Ping error: %s", err) + } + } + case <-dht.ctx.Done(): + return + } + } +} + // Bootstrap builds up list of peers by requesting random peer IDs func (dht *IpfsDHT) Bootstrap(ctx context.Context) { id := make([]byte, 16) rand.Read(id) - dht.FindPeer(ctx, peer.ID(id)) + _, err := dht.FindPeer(ctx, peer.ID(id)) + if err != nil { + log.Error("Bootstrap peer error: %s", err) + } } diff --git a/routing/dht/dht_logger.go b/routing/dht/dht_logger.go index 1a0878bf7a4..0ff012956cb 100644 --- a/routing/dht/dht_logger.go +++ b/routing/dht/dht_logger.go @@ -2,6 +2,7 @@ package dht import ( "encoding/json" + "fmt" "time" ) @@ -29,12 +30,16 @@ func (l *logDhtRPC) EndLog() { func (l *logDhtRPC) Print() { b, err := json.Marshal(l) if err != nil { - log.Debug(err.Error()) + log.Debug("Error marshaling logDhtRPC object: %s", err) } else { log.Debug(string(b)) } } +func (l *logDhtRPC) String() string { + return fmt.Sprintf("DHT RPC: %s took %s, success = %s", l.Type, l.Duration, l.Success) +} + func (l *logDhtRPC) EndAndPrint() { l.EndLog() l.Print() diff --git a/routing/dht/dht_test.go b/routing/dht/dht_test.go index 1d7413fd53b..98b196da531 100644 --- a/routing/dht/dht_test.go +++ b/routing/dht/dht_test.go @@ -10,7 +10,6 @@ import ( ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" ci "github.com/jbenet/go-ipfs/crypto" - spipe "github.com/jbenet/go-ipfs/crypto/spipe" inet "github.com/jbenet/go-ipfs/net" mux "github.com/jbenet/go-ipfs/net/mux" netservice "github.com/jbenet/go-ipfs/net/service" @@ -21,9 +20,7 @@ import ( "time" ) -func setupDHT(t *testing.T, p *peer.Peer) *IpfsDHT { - ctx := context.Background() - +func setupDHT(ctx context.Context, t *testing.T, p *peer.Peer) *IpfsDHT { peerstore := peer.NewPeerstore() dhts := netservice.NewService(nil) // nil handler for now, need to patch it @@ -38,12 +35,12 @@ func setupDHT(t *testing.T, p *peer.Peer) *IpfsDHT { t.Fatal(err) } - d := NewDHT(p, peerstore, net, dhts, ds.NewMapDatastore()) + d := NewDHT(ctx, p, peerstore, net, dhts, ds.NewMapDatastore()) dhts.SetHandler(d) return d } -func setupDHTS(n int, t *testing.T) ([]ma.Multiaddr, []*peer.Peer, []*IpfsDHT) { +func setupDHTS(ctx context.Context, n int, t *testing.T) ([]ma.Multiaddr, []*peer.Peer, []*IpfsDHT) { var addrs []ma.Multiaddr for i := 0; i < n; i++ { a, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", 5000+i)) @@ -61,7 +58,7 @@ func setupDHTS(n int, t *testing.T) ([]ma.Multiaddr, []*peer.Peer, []*IpfsDHT) { dhts := make([]*IpfsDHT, n) for i := 0; i < n; i++ { - dhts[i] = setupDHT(t, peers[i]) + dhts[i] = setupDHT(ctx, t, peers[i]) } return addrs, peers, dhts @@ -76,7 +73,7 @@ func makePeer(addr ma.Multiaddr) *peer.Peer { } p.PrivKey = sk p.PubKey = pk - id, err := spipe.IDFromPubKey(pk) + id, err := peer.IDFromPubKey(pk) if err != nil { panic(err) } @@ -87,7 +84,7 @@ func makePeer(addr ma.Multiaddr) *peer.Peer { func TestPing(t *testing.T) { // t.Skip("skipping test to debug another") - + ctx := context.Background() u.Debug = false addrA, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/2222") if err != nil { @@ -101,28 +98,28 @@ func TestPing(t *testing.T) { peerA := makePeer(addrA) peerB := makePeer(addrB) - dhtA := setupDHT(t, peerA) - dhtB := setupDHT(t, peerB) + dhtA := setupDHT(ctx, t, peerA) + dhtB := setupDHT(ctx, t, peerB) defer dhtA.Halt() defer dhtB.Halt() defer dhtA.network.Close() defer dhtB.network.Close() - _, err = dhtA.Connect(context.Background(), peerB) + _, err = dhtA.Connect(ctx, peerB) if err != nil { t.Fatal(err) } //Test that we can ping the node - ctx, _ := context.WithTimeout(context.Background(), 5*time.Millisecond) - err = dhtA.Ping(ctx, peerB) + ctxT, _ := context.WithTimeout(ctx, 100*time.Millisecond) + err = dhtA.Ping(ctxT, peerB) if err != nil { t.Fatal(err) } - ctx, _ = context.WithTimeout(context.Background(), 5*time.Millisecond) - err = dhtB.Ping(ctx, peerA) + ctxT, _ = context.WithTimeout(ctx, 100*time.Millisecond) + err = dhtB.Ping(ctxT, peerA) if err != nil { t.Fatal(err) } @@ -131,12 +128,13 @@ func TestPing(t *testing.T) { func TestValueGetSet(t *testing.T) { // t.Skip("skipping test to debug another") + ctx := context.Background() u.Debug = false - addrA, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/1235") + addrA, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/11235") if err != nil { t.Fatal(err) } - addrB, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/5679") + addrB, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/15679") if err != nil { t.Fatal(err) } @@ -144,23 +142,23 @@ func TestValueGetSet(t *testing.T) { peerA := makePeer(addrA) peerB := makePeer(addrB) - dhtA := setupDHT(t, peerA) - dhtB := setupDHT(t, peerB) + dhtA := setupDHT(ctx, t, peerA) + dhtB := setupDHT(ctx, t, peerB) defer dhtA.Halt() defer dhtB.Halt() defer dhtA.network.Close() defer dhtB.network.Close() - _, err = dhtA.Connect(context.Background(), peerB) + _, err = dhtA.Connect(ctx, peerB) if err != nil { t.Fatal(err) } - ctxT, _ := context.WithTimeout(context.Background(), time.Second) + ctxT, _ := context.WithTimeout(ctx, time.Second) dhtA.PutValue(ctxT, "hello", []byte("world")) - ctxT, _ = context.WithTimeout(context.Background(), time.Second*2) + ctxT, _ = context.WithTimeout(ctx, time.Second*2) val, err := dhtA.GetValue(ctxT, "hello") if err != nil { t.Fatal(err) @@ -170,7 +168,7 @@ func TestValueGetSet(t *testing.T) { t.Fatalf("Expected 'world' got '%s'", string(val)) } - ctxT, _ = context.WithTimeout(context.Background(), time.Second*2) + ctxT, _ = context.WithTimeout(ctx, time.Second*2) val, err = dhtB.GetValue(ctxT, "hello") if err != nil { t.Fatal(err) @@ -183,10 +181,11 @@ func TestValueGetSet(t *testing.T) { func TestProvides(t *testing.T) { // t.Skip("skipping test to debug another") + ctx := context.Background() u.Debug = false - _, peers, dhts := setupDHTS(4, t) + _, peers, dhts := setupDHTS(ctx, 4, t) defer func() { for i := 0; i < 4; i++ { dhts[i].Halt() @@ -194,17 +193,17 @@ func TestProvides(t *testing.T) { } }() - _, err := dhts[0].Connect(context.Background(), peers[1]) + _, err := dhts[0].Connect(ctx, peers[1]) if err != nil { t.Fatal(err) } - _, err = dhts[1].Connect(context.Background(), peers[2]) + _, err = dhts[1].Connect(ctx, peers[2]) if err != nil { t.Fatal(err) } - _, err = dhts[1].Connect(context.Background(), peers[3]) + _, err = dhts[1].Connect(ctx, peers[3]) if err != nil { t.Fatal(err) } @@ -219,30 +218,34 @@ func TestProvides(t *testing.T) { t.Fatal(err) } - err = dhts[3].Provide(context.Background(), u.Key("hello")) + err = dhts[3].Provide(ctx, u.Key("hello")) if err != nil { t.Fatal(err) } time.Sleep(time.Millisecond * 60) - ctxT, _ := context.WithTimeout(context.Background(), time.Second) - provs, err := dhts[0].FindProviders(ctxT, u.Key("hello")) - if err != nil { - t.Fatal(err) - } + ctxT, _ := context.WithTimeout(ctx, time.Second) + provchan := dhts[0].FindProvidersAsync(ctxT, u.Key("hello"), 1) - if len(provs) != 1 { - t.Fatal("Didnt get back providers") + after := time.After(time.Second) + select { + case prov := <-provchan: + if prov == nil { + t.Fatal("Got back nil provider") + } + case <-after: + t.Fatal("Did not get a provider back.") } } func TestProvidesAsync(t *testing.T) { // t.Skip("skipping test to debug another") + ctx := context.Background() u.Debug = false - _, peers, dhts := setupDHTS(4, t) + _, peers, dhts := setupDHTS(ctx, 4, t) defer func() { for i := 0; i < 4; i++ { dhts[i].Halt() @@ -250,17 +253,17 @@ func TestProvidesAsync(t *testing.T) { } }() - _, err := dhts[0].Connect(context.Background(), peers[1]) + _, err := dhts[0].Connect(ctx, peers[1]) if err != nil { t.Fatal(err) } - _, err = dhts[1].Connect(context.Background(), peers[2]) + _, err = dhts[1].Connect(ctx, peers[2]) if err != nil { t.Fatal(err) } - _, err = dhts[1].Connect(context.Background(), peers[3]) + _, err = dhts[1].Connect(ctx, peers[3]) if err != nil { t.Fatal(err) } @@ -275,21 +278,21 @@ func TestProvidesAsync(t *testing.T) { t.Fatal(err) } - err = dhts[3].Provide(context.Background(), u.Key("hello")) + err = dhts[3].Provide(ctx, u.Key("hello")) if err != nil { t.Fatal(err) } time.Sleep(time.Millisecond * 60) - ctx, _ := context.WithTimeout(context.TODO(), time.Millisecond*300) - provs := dhts[0].FindProvidersAsync(ctx, u.Key("hello"), 5) + ctxT, _ := context.WithTimeout(ctx, time.Millisecond*300) + provs := dhts[0].FindProvidersAsync(ctxT, u.Key("hello"), 5) select { case p := <-provs: if !p.ID.Equal(dhts[3].self.ID) { t.Fatalf("got a provider, but not the right one. %s", p) } - case <-ctx.Done(): + case <-ctxT.Done(): t.Fatal("Didnt get back providers") } } @@ -297,8 +300,9 @@ func TestProvidesAsync(t *testing.T) { func TestLayeredGet(t *testing.T) { // t.Skip("skipping test to debug another") + ctx := context.Background() u.Debug = false - _, peers, dhts := setupDHTS(4, t) + _, peers, dhts := setupDHTS(ctx, 4, t) defer func() { for i := 0; i < 4; i++ { dhts[i].Halt() @@ -306,17 +310,17 @@ func TestLayeredGet(t *testing.T) { } }() - _, err := dhts[0].Connect(context.Background(), peers[1]) + _, err := dhts[0].Connect(ctx, peers[1]) if err != nil { t.Fatalf("Failed to connect: %s", err) } - _, err = dhts[1].Connect(context.Background(), peers[2]) + _, err = dhts[1].Connect(ctx, peers[2]) if err != nil { t.Fatal(err) } - _, err = dhts[1].Connect(context.Background(), peers[3]) + _, err = dhts[1].Connect(ctx, peers[3]) if err != nil { t.Fatal(err) } @@ -326,14 +330,14 @@ func TestLayeredGet(t *testing.T) { t.Fatal(err) } - err = dhts[3].Provide(context.Background(), u.Key("hello")) + err = dhts[3].Provide(ctx, u.Key("hello")) if err != nil { t.Fatal(err) } time.Sleep(time.Millisecond * 60) - ctxT, _ := context.WithTimeout(context.Background(), time.Second) + ctxT, _ := context.WithTimeout(ctx, time.Second) val, err := dhts[0].GetValue(ctxT, u.Key("hello")) if err != nil { t.Fatal(err) @@ -348,9 +352,10 @@ func TestLayeredGet(t *testing.T) { func TestFindPeer(t *testing.T) { // t.Skip("skipping test to debug another") + ctx := context.Background() u.Debug = false - _, peers, dhts := setupDHTS(4, t) + _, peers, dhts := setupDHTS(ctx, 4, t) defer func() { for i := 0; i < 4; i++ { dhts[i].Halt() @@ -358,22 +363,22 @@ func TestFindPeer(t *testing.T) { } }() - _, err := dhts[0].Connect(context.Background(), peers[1]) + _, err := dhts[0].Connect(ctx, peers[1]) if err != nil { t.Fatal(err) } - _, err = dhts[1].Connect(context.Background(), peers[2]) + _, err = dhts[1].Connect(ctx, peers[2]) if err != nil { t.Fatal(err) } - _, err = dhts[1].Connect(context.Background(), peers[3]) + _, err = dhts[1].Connect(ctx, peers[3]) if err != nil { t.Fatal(err) } - ctxT, _ := context.WithTimeout(context.Background(), time.Second) + ctxT, _ := context.WithTimeout(ctx, time.Second) p, err := dhts[0].FindPeer(ctxT, peers[2].ID) if err != nil { t.Fatal(err) @@ -387,3 +392,65 @@ func TestFindPeer(t *testing.T) { t.Fatal("Didnt find expected peer.") } } + +func TestConnectCollision(t *testing.T) { + // t.Skip("skipping test to debug another") + + runTimes := 10 + + for rtime := 0; rtime < runTimes; rtime++ { + log.Notice("Running Time: ", rtime) + + ctx := context.Background() + u.Debug = false + addrA, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/11235") + if err != nil { + t.Fatal(err) + } + addrB, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/15679") + if err != nil { + t.Fatal(err) + } + + peerA := makePeer(addrA) + peerB := makePeer(addrB) + + dhtA := setupDHT(ctx, t, peerA) + dhtB := setupDHT(ctx, t, peerB) + + done := make(chan struct{}) + go func() { + _, err = dhtA.Connect(ctx, peerB) + if err != nil { + t.Fatal(err) + } + done <- struct{}{} + }() + go func() { + _, err = dhtB.Connect(ctx, peerA) + if err != nil { + t.Fatal(err) + } + done <- struct{}{} + }() + + timeout := time.After(time.Second) + select { + case <-done: + case <-timeout: + t.Fatal("Timeout received!") + } + select { + case <-done: + case <-timeout: + t.Fatal("Timeout received!") + } + + dhtA.Halt() + dhtB.Halt() + dhtA.network.Close() + dhtB.network.Close() + + <-time.After(200 * time.Millisecond) + } +} diff --git a/routing/dht/ext_test.go b/routing/dht/ext_test.go index 88f51237859..b5bf4877228 100644 --- a/routing/dht/ext_test.go +++ b/routing/dht/ext_test.go @@ -92,6 +92,10 @@ func (f *fauxNet) GetPeerList() []*peer.Peer { return nil } +func (f *fauxNet) GetBandwidthTotals() (uint64, uint64) { + return 0, 0 +} + // Close terminates all network operation func (f *fauxNet) Close() error { return nil } @@ -106,7 +110,7 @@ func TestGetFailures(t *testing.T) { local := new(peer.Peer) local.ID = peer.ID("test_peer") - d := NewDHT(local, peerstore, fn, fs, ds.NewMapDatastore()) + d := NewDHT(ctx, local, peerstore, fn, fs, ds.NewMapDatastore()) other := &peer.Peer{ID: peer.ID("other_peer")} d.Update(other) @@ -196,6 +200,7 @@ func _randPeer() *peer.Peer { func TestNotFound(t *testing.T) { // t.Skip("skipping test because it makes a lot of output") + ctx := context.Background() fn := &fauxNet{} fs := &fauxSender{} @@ -203,7 +208,7 @@ func TestNotFound(t *testing.T) { local.ID = peer.ID("test_peer") peerstore := peer.NewPeerstore() - d := NewDHT(local, peerstore, fn, fs, ds.NewMapDatastore()) + d := NewDHT(ctx, local, peerstore, fn, fs, ds.NewMapDatastore()) var ps []*peer.Peer for i := 0; i < 5; i++ { @@ -239,7 +244,7 @@ func TestNotFound(t *testing.T) { }) - ctx, _ := context.WithTimeout(context.Background(), time.Second*5) + ctx, _ = context.WithTimeout(ctx, time.Second*5) v, err := d.GetValue(ctx, u.Key("hello")) log.Debug("get value got %v", v) if err != nil { @@ -261,6 +266,7 @@ func TestNotFound(t *testing.T) { func TestLessThanKResponses(t *testing.T) { // t.Skip("skipping test because it makes a lot of output") + ctx := context.Background() u.Debug = false fn := &fauxNet{} fs := &fauxSender{} @@ -268,7 +274,7 @@ func TestLessThanKResponses(t *testing.T) { local := new(peer.Peer) local.ID = peer.ID("test_peer") - d := NewDHT(local, peerstore, fn, fs, ds.NewMapDatastore()) + d := NewDHT(ctx, local, peerstore, fn, fs, ds.NewMapDatastore()) var ps []*peer.Peer for i := 0; i < 5; i++ { @@ -303,7 +309,7 @@ func TestLessThanKResponses(t *testing.T) { }) - ctx, _ := context.WithTimeout(context.Background(), time.Second*30) + ctx, _ = context.WithTimeout(ctx, time.Second*30) _, err := d.GetValue(ctx, u.Key("hello")) if err != nil { switch err { diff --git a/routing/dht/handlers.go b/routing/dht/handlers.go index 417dd0918f1..1046516b66d 100644 --- a/routing/dht/handlers.go +++ b/routing/dht/handlers.go @@ -5,14 +5,14 @@ import ( "fmt" "time" - msg "github.com/jbenet/go-ipfs/net/message" peer "github.com/jbenet/go-ipfs/peer" - kb "github.com/jbenet/go-ipfs/routing/kbucket" u "github.com/jbenet/go-ipfs/util" ds "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/datastore.go" ) +var CloserPeerCount = 4 + // dhthandler specifies the signature of functions that handle DHT messages. type dhtHandler func(*peer.Peer, *Message) (*Message, error) @@ -30,8 +30,6 @@ func (dht *IpfsDHT) handlerForMsgType(t Message_MessageType) dhtHandler { return dht.handleGetProviders case Message_PING: return dht.handlePing - case Message_DIAGNOSTIC: - return dht.handleDiagnostic default: return nil } @@ -83,10 +81,12 @@ func (dht *IpfsDHT) handleGetValue(p *peer.Peer, pmes *Message) (*Message, error } // Find closest peer on given cluster to desired key and reply with that info - closer := dht.betterPeerToQuery(pmes) + closer := dht.betterPeersToQuery(pmes, CloserPeerCount) if closer != nil { - log.Debug("handleGetValue returning a closer peer: '%s'\n", closer) - resp.CloserPeers = peersToPBPeers([]*peer.Peer{closer}) + for _, p := range closer { + log.Debug("handleGetValue returning closer peer: '%s'", p) + } + resp.CloserPeers = peersToPBPeers(closer) } return resp, nil @@ -109,27 +109,31 @@ func (dht *IpfsDHT) handlePing(p *peer.Peer, pmes *Message) (*Message, error) { func (dht *IpfsDHT) handleFindPeer(p *peer.Peer, pmes *Message) (*Message, error) { resp := newMessage(pmes.GetType(), "", pmes.GetClusterLevel()) - var closest *peer.Peer + var closest []*peer.Peer // if looking for self... special case where we send it on CloserPeers. if peer.ID(pmes.GetKey()).Equal(dht.self.ID) { - closest = dht.self + closest = []*peer.Peer{dht.self} } else { - closest = dht.betterPeerToQuery(pmes) + closest = dht.betterPeersToQuery(pmes, CloserPeerCount) } if closest == nil { - log.Error("handleFindPeer: could not find anything.\n") + log.Error("handleFindPeer: could not find anything.") return resp, nil } - if len(closest.Addresses) == 0 { - log.Error("handleFindPeer: no addresses for connected peer...\n") - return resp, nil + var withAddresses []*peer.Peer + for _, p := range closest { + if len(p.Addresses) > 0 { + withAddresses = append(withAddresses, p) + } } - log.Debug("handleFindPeer: sending back '%s'\n", closest) - resp.CloserPeers = peersToPBPeers([]*peer.Peer{closest}) + for _, p := range withAddresses { + log.Debug("handleFindPeer: sending back '%s'", p) + } + resp.CloserPeers = peersToPBPeers(withAddresses) return resp, nil } @@ -157,9 +161,9 @@ func (dht *IpfsDHT) handleGetProviders(p *peer.Peer, pmes *Message) (*Message, e } // Also send closer peers. - closer := dht.betterPeerToQuery(pmes) + closer := dht.betterPeersToQuery(pmes, CloserPeerCount) if closer != nil { - resp.CloserPeers = peersToPBPeers([]*peer.Peer{closer}) + resp.CloserPeers = peersToPBPeers(closer) } return resp, nil @@ -175,7 +179,26 @@ func (dht *IpfsDHT) handleAddProvider(p *peer.Peer, pmes *Message) (*Message, er log.Debug("%s adding %s as a provider for '%s'\n", dht.self, p, peer.ID(key)) - dht.providers.AddProvider(key, p) + // add provider should use the address given in the message + for _, pb := range pmes.GetProviderPeers() { + pid := peer.ID(pb.GetId()) + if pid.Equal(p.ID) { + + addr, err := pb.Address() + if err != nil { + log.Error("provider %s error with address %s", p, *pb.Addr) + continue + } + + log.Info("received provider %s %s for %s", p, addr, key) + p.AddAddress(addr) + dht.providers.AddProvider(key, p) + + } else { + log.Error("handleAddProvider received provider %s from %s", pid, p) + } + } + return pmes, nil // send back same msg as confirmation. } @@ -184,53 +207,3 @@ func (dht *IpfsDHT) handleAddProvider(p *peer.Peer, pmes *Message) (*Message, er func (dht *IpfsDHT) Halt() { dht.providers.Halt() } - -// NOTE: not yet finished, low priority -func (dht *IpfsDHT) handleDiagnostic(p *peer.Peer, pmes *Message) (*Message, error) { - seq := dht.routingTables[0].NearestPeers(kb.ConvertPeerID(dht.self.ID), 10) - - for _, ps := range seq { - _, err := msg.FromObject(ps, pmes) - if err != nil { - log.Error("handleDiagnostics error creating message: %v\n", err) - continue - } - // dht.sender.SendRequest(context.TODO(), mes) - } - return nil, errors.New("not yet ported back") - - // buf := new(bytes.Buffer) - // di := dht.getDiagInfo() - // buf.Write(di.Marshal()) - // - // // NOTE: this shouldnt be a hardcoded value - // after := time.After(time.Second * 20) - // count := len(seq) - // for count > 0 { - // select { - // case <-after: - // //Timeout, return what we have - // goto out - // case reqResp := <-listenChan: - // pmesOut := new(Message) - // err := proto.Unmarshal(reqResp.Data, pmesOut) - // if err != nil { - // // It broke? eh, whatever, keep going - // continue - // } - // buf.Write(reqResp.Data) - // count-- - // } - // } - // - // out: - // resp := Message{ - // Type: Message_DIAGNOSTIC, - // ID: pmes.GetId(), - // Value: buf.Bytes(), - // Response: true, - // } - // - // mes := swarm.NewMessage(p, resp.ToProtobuf()) - // dht.netChan.Outgoing <- mes -} diff --git a/routing/dht/messages.pb.go b/routing/dht/messages.pb.go index b6e9fa4f26c..2da77e7bc20 100644 --- a/routing/dht/messages.pb.go +++ b/routing/dht/messages.pb.go @@ -1,4 +1,4 @@ -// Code generated by protoc-gen-gogo. +// Code generated by protoc-gen-go. // source: messages.proto // DO NOT EDIT! @@ -13,13 +13,11 @@ It has these top-level messages: */ package dht -import proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/gogoprotobuf/proto" -import json "encoding/json" +import proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto" import math "math" -// Reference proto, json, and math imports to suppress error if they are not otherwise used. +// Reference imports to suppress errors if they are not otherwise used. var _ = proto.Marshal -var _ = &json.SyntaxError{} var _ = math.Inf type Message_MessageType int32 @@ -31,7 +29,6 @@ const ( Message_GET_PROVIDERS Message_MessageType = 3 Message_FIND_NODE Message_MessageType = 4 Message_PING Message_MessageType = 5 - Message_DIAGNOSTIC Message_MessageType = 6 ) var Message_MessageType_name = map[int32]string{ @@ -41,7 +38,6 @@ var Message_MessageType_name = map[int32]string{ 3: "GET_PROVIDERS", 4: "FIND_NODE", 5: "PING", - 6: "DIAGNOSTIC", } var Message_MessageType_value = map[string]int32{ "PUT_VALUE": 0, @@ -50,7 +46,6 @@ var Message_MessageType_value = map[string]int32{ "GET_PROVIDERS": 3, "FIND_NODE": 4, "PING": 5, - "DIAGNOSTIC": 6, } func (x Message_MessageType) Enum() *Message_MessageType { @@ -72,7 +67,7 @@ func (x *Message_MessageType) UnmarshalJSON(data []byte) error { type Message struct { // defines what type of message it is. - Type *Message_MessageType `protobuf:"varint,1,req,name=type,enum=dht.Message_MessageType" json:"type,omitempty"` + Type *Message_MessageType `protobuf:"varint,1,opt,name=type,enum=dht.Message_MessageType" json:"type,omitempty"` // defines what coral cluster level this query/response belongs to. ClusterLevelRaw *int32 `protobuf:"varint,10,opt,name=clusterLevelRaw" json:"clusterLevelRaw,omitempty"` // Used to specify the key associated with this message. @@ -137,8 +132,8 @@ func (m *Message) GetProviderPeers() []*Message_Peer { } type Message_Peer struct { - Id *string `protobuf:"bytes,1,req,name=id" json:"id,omitempty"` - Addr *string `protobuf:"bytes,2,req,name=addr" json:"addr,omitempty"` + Id *string `protobuf:"bytes,1,opt,name=id" json:"id,omitempty"` + Addr *string `protobuf:"bytes,2,opt,name=addr" json:"addr,omitempty"` XXX_unrecognized []byte `json:"-"` } diff --git a/routing/dht/messages.proto b/routing/dht/messages.proto index 3c33f9382b2..0676901504c 100644 --- a/routing/dht/messages.proto +++ b/routing/dht/messages.proto @@ -10,16 +10,15 @@ message Message { GET_PROVIDERS = 3; FIND_NODE = 4; PING = 5; - DIAGNOSTIC = 6; } message Peer { - required string id = 1; - required string addr = 2; + optional string id = 1; + optional string addr = 2; } // defines what type of message it is. - required MessageType type = 1; + optional MessageType type = 1; // defines what coral cluster level this query/response belongs to. optional int32 clusterLevelRaw = 10; diff --git a/routing/dht/routing.go b/routing/dht/routing.go index c14031ce2b5..55ef265cb7b 100644 --- a/routing/dht/routing.go +++ b/routing/dht/routing.go @@ -1,8 +1,7 @@ package dht import ( - "bytes" - "encoding/json" + "sync" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" @@ -61,6 +60,7 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key u.Key) ([]byte, error) { routeLevel := 0 closest := dht.routingTables[routeLevel].NearestPeers(kb.ConvertKey(key), PoolSize) if closest == nil || len(closest) == 0 { + log.Warning("Got no peers back from routing table!") return nil, nil } @@ -117,26 +117,7 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key u.Key) error { return nil } -// NB: not actually async. Used to keep the interface consistent while the -// actual async method, FindProvidersAsync2 is under construction func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key u.Key, count int) <-chan *peer.Peer { - ch := make(chan *peer.Peer) - providers, err := dht.FindProviders(ctx, key) - if err != nil { - close(ch) - return ch - } - go func() { - defer close(ch) - for _, p := range providers { - ch <- p - } - }() - return ch -} - -// FIXME: there's a bug here! -func (dht *IpfsDHT) FindProvidersAsync2(ctx context.Context, key u.Key, count int) <-chan *peer.Peer { peerOut := make(chan *peer.Peer, count) go func() { ps := newPeerSet() @@ -151,9 +132,12 @@ func (dht *IpfsDHT) FindProvidersAsync2(ctx context.Context, key u.Key, count in } } + wg := new(sync.WaitGroup) peers := dht.routingTables[0].NearestPeers(kb.ConvertKey(key), AlphaValue) for _, pp := range peers { + wg.Add(1) go func(p *peer.Peer) { + defer wg.Done() pmes, err := dht.findProvidersSingle(ctx, p, key, 0) if err != nil { log.Error("%s", err) @@ -162,7 +146,8 @@ func (dht *IpfsDHT) FindProvidersAsync2(ctx context.Context, key u.Key, count in dht.addPeerListAsync(key, pmes.GetProviderPeers(), ps, count, peerOut) }(pp) } - + wg.Wait() + close(peerOut) }() return peerOut } @@ -186,61 +171,16 @@ func (dht *IpfsDHT) addPeerListAsync(k u.Key, peers []*Message_Peer, ps *peerSet } } -// FindProviders searches for peers who can provide the value for given key. -func (dht *IpfsDHT) FindProviders(ctx context.Context, key u.Key) ([]*peer.Peer, error) { - // get closest peer - log.Debug("Find providers for: '%s'", key) - p := dht.routingTables[0].NearestPeer(kb.ConvertKey(key)) - if p == nil { - log.Warning("Got no nearest peer for find providers: '%s'", key) - return nil, nil - } - - for level := 0; level < len(dht.routingTables); { - - // attempt retrieving providers - pmes, err := dht.findProvidersSingle(ctx, p, key, level) - if err != nil { - return nil, err - } - - // handle providers - provs := pmes.GetProviderPeers() - if provs != nil { - log.Debug("Got providers back from findProviders call!") - return dht.addProviders(key, provs), nil - } - - log.Debug("Didnt get providers, just closer peers.") - closer := pmes.GetCloserPeers() - if len(closer) == 0 { - level++ - continue - } - - np, err := dht.peerFromInfo(closer[0]) - if err != nil { - log.Debug("no peerFromInfo") - level++ - continue - } - p = np - } - return nil, u.ErrNotFound -} - // Find specific Peer - // FindPeer searches for a peer with given ID. func (dht *IpfsDHT) FindPeer(ctx context.Context, id peer.ID) (*peer.Peer, error) { // Check if were already connected to them - p, _ := dht.Find(id) + p, _ := dht.FindLocal(id) if p != nil { return p, nil } - // @whyrusleeping why is this here? doesn't the dht.Find above cover it? routeLevel := 0 p = dht.routingTables[routeLevel].NearestPeer(kb.ConvertPeerID(id)) if p == nil { @@ -277,7 +217,7 @@ func (dht *IpfsDHT) FindPeer(ctx context.Context, id peer.ID) (*peer.Peer, error func (dht *IpfsDHT) findPeerMultiple(ctx context.Context, id peer.ID) (*peer.Peer, error) { // Check if were already connected to them - p, _ := dht.Find(id) + p, _ := dht.FindLocal(id) if p != nil { return p, nil } @@ -341,33 +281,3 @@ func (dht *IpfsDHT) Ping(ctx context.Context, p *peer.Peer) error { log.Info("ping %s end (err = %s)", p, err) return err } - -func (dht *IpfsDHT) getDiagnostic(ctx context.Context) ([]*diagInfo, error) { - - log.Info("Begin Diagnostic") - peers := dht.routingTables[0].NearestPeers(kb.ConvertPeerID(dht.self.ID), 10) - var out []*diagInfo - - query := newQuery(dht.self.Key(), func(ctx context.Context, p *peer.Peer) (*dhtQueryResult, error) { - pmes := newMessage(Message_DIAGNOSTIC, "", 0) - rpmes, err := dht.sendRequest(ctx, p, pmes) - if err != nil { - return nil, err - } - - dec := json.NewDecoder(bytes.NewBuffer(rpmes.GetValue())) - for { - di := new(diagInfo) - err := dec.Decode(di) - if err != nil { - break - } - - out = append(out, di) - } - return &dhtQueryResult{success: true}, nil - }) - - _, err := query.Run(ctx, peers) - return out, err -} diff --git a/routing/kbucket/util.go b/routing/kbucket/util.go index 3aca06f6afe..02994230a62 100644 --- a/routing/kbucket/util.go +++ b/routing/kbucket/util.go @@ -31,11 +31,11 @@ func (id ID) less(other ID) bool { } func xor(a, b ID) ID { - return ID(ks.XOR(a, b)) + return ID(u.XOR(a, b)) } func commonPrefixLen(a, b ID) int { - return ks.ZeroPrefixLen(ks.XOR(a, b)) + return ks.ZeroPrefixLen(u.XOR(a, b)) } // ConvertPeerID creates a DHT ID by hashing a Peer ID (Multihash) diff --git a/routing/keyspace/xor.go b/routing/keyspace/xor.go index dbb7c68516a..7159f2cadd7 100644 --- a/routing/keyspace/xor.go +++ b/routing/keyspace/xor.go @@ -4,6 +4,8 @@ import ( "bytes" "crypto/sha256" "math/big" + + u "github.com/jbenet/go-ipfs/util" ) // XORKeySpace is a KeySpace which: @@ -33,7 +35,7 @@ func (s *xorKeySpace) Equal(k1, k2 Key) bool { // Distance returns the distance metric in this key space func (s *xorKeySpace) Distance(k1, k2 Key) *big.Int { // XOR the keys - k3 := XOR(k1.Bytes, k2.Bytes) + k3 := u.XOR(k1.Bytes, k2.Bytes) // interpret it as an integer dist := big.NewInt(0).SetBytes(k3) @@ -52,15 +54,6 @@ func (s *xorKeySpace) Less(k1, k2 Key) bool { return true } -// XOR takes two byte slices, XORs them together, returns the resulting slice. -func XOR(a, b []byte) []byte { - c := make([]byte, len(a)) - for i := 0; i < len(a); i++ { - c[i] = a[i] ^ b[i] - } - return c -} - // ZeroPrefixLen returns the number of consecutive zeroes in a byte slice. func ZeroPrefixLen(id []byte) int { for i := 0; i < len(id); i++ { diff --git a/routing/keyspace/xor_test.go b/routing/keyspace/xor_test.go index 7963ea014a8..8db4b926cdb 100644 --- a/routing/keyspace/xor_test.go +++ b/routing/keyspace/xor_test.go @@ -4,34 +4,9 @@ import ( "bytes" "math/big" "testing" -) - -func TestXOR(t *testing.T) { - cases := [][3][]byte{ - [3][]byte{ - []byte{0xFF, 0xFF, 0xFF}, - []byte{0xFF, 0xFF, 0xFF}, - []byte{0x00, 0x00, 0x00}, - }, - [3][]byte{ - []byte{0x00, 0xFF, 0x00}, - []byte{0xFF, 0xFF, 0xFF}, - []byte{0xFF, 0x00, 0xFF}, - }, - [3][]byte{ - []byte{0x55, 0x55, 0x55}, - []byte{0x55, 0xFF, 0xAA}, - []byte{0x00, 0xAA, 0xFF}, - }, - } - for _, c := range cases { - r := XOR(c[0], c[1]) - if !bytes.Equal(r, c[2]) { - t.Error("XOR failed") - } - } -} + u "github.com/jbenet/go-ipfs/util" +) func TestPrefixLen(t *testing.T) { cases := [][]byte{ @@ -126,7 +101,7 @@ func TestDistancesAndCenterSorting(t *testing.T) { } d1 := keys[2].Distance(keys[5]) - d2 := XOR(keys[2].Bytes, keys[5].Bytes) + d2 := u.XOR(keys[2].Bytes, keys[5].Bytes) d2 = d2[len(keys[2].Bytes)-len(d1.Bytes()):] // skip empty space for big if !bytes.Equal(d1.Bytes(), d2) { t.Errorf("bytes should be the same. %v == %v", d1.Bytes(), d2) diff --git a/routing/routing.go b/routing/routing.go index 4669fb48c7b..f3dd0c9d86a 100644 --- a/routing/routing.go +++ b/routing/routing.go @@ -26,11 +26,7 @@ type IpfsRouting interface { // Announce that this node can provide value for given key Provide(context.Context, u.Key) error - // FindProviders searches for peers who can provide the value for given key. - FindProviders(context.Context, u.Key) ([]*peer.Peer, error) - // Find specific Peer - // FindPeer searches for a peer with given ID. FindPeer(context.Context, peer.ID) (*peer.Peer, error) } diff --git a/util/ctxcloser/closer.go b/util/ctxcloser/closer.go new file mode 100644 index 00000000000..e04178c2444 --- /dev/null +++ b/util/ctxcloser/closer.go @@ -0,0 +1,154 @@ +package ctxcloser + +import ( + "sync" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +// CloseFunc is a function used to close a ContextCloser +type CloseFunc func() error + +// ContextCloser is an interface for services able to be opened and closed. +// It has a parent Context, and Children. But ContextCloser is not a proper +// "tree" like the Context tree. It is more like a Context-WaitGroup hybrid. +// It models a main object with a few children objects -- and, unlike the +// context -- concerns itself with the parent-child closing semantics: +// +// - Can define a CloseFunc (func() error) to be run at Close time. +// - Children call Children().Add(1) to be waited upon +// - Children can select on <-Closing() to know when they should shut down. +// - Close() will wait until all children call Children().Done() +// - <-Closed() signals when the service is completely closed. +// +// ContextCloser can be embedded into the main object itself. In that case, +// the closeFunc (if a member function) has to be set after the struct +// is intialized: +// +// type service struct { +// ContextCloser +// net.Conn +// } +// +// func (s *service) close() error { +// return s.Conn.Close() +// } +// +// func newService(ctx context.Context, c net.Conn) *service { +// s := &service{c} +// s.ContextCloser = NewContextCloser(ctx, s.close) +// return s +// } +// +type ContextCloser interface { + + // Context is the context of this ContextCloser. It is "sort of" a parent. + Context() context.Context + + // Children is a sync.Waitgroup for all children goroutines that should + // shut down completely before this service is said to be "closed". + // Follows the semantics of WaitGroup: + // Children().Add(1) // add one more dependent child + // Children().Done() // child signals it is done + Children() *sync.WaitGroup + + // Close is a method to call when you wish to stop this ContextCloser + Close() error + + // Closing is a signal to wait upon, like Context.Done(). + // It fires when the object should be closing (but hasn't yet fully closed). + // The primary use case is for child goroutines who need to know when + // they should shut down. (equivalent to Context().Done()) + Closing() <-chan struct{} + + // Closed is a method to wait upon, like Context.Done(). + // It fires when the entire object is fully closed. + // The primary use case is for external listeners who need to know when + // this object is completly done, and all its children closed. + Closed() <-chan struct{} +} + +// contextCloser is an OpenCloser with a cancellable context +type contextCloser struct { + ctx context.Context + cancel context.CancelFunc + + // called to run the close logic. + closeFunc CloseFunc + + // closed is released once the close function is done. + closed chan struct{} + + // wait group for child goroutines + children sync.WaitGroup + + // sync primitive to ensure the close logic is only called once. + closeOnce sync.Once + + // error to return to clients of Close(). + closeErr error +} + +// NewContextCloser constructs and returns a ContextCloser. It will call +// cf CloseFunc before its Done() Wait signals fire. +func NewContextCloser(ctx context.Context, cf CloseFunc) ContextCloser { + ctx, cancel := context.WithCancel(ctx) + c := &contextCloser{ + ctx: ctx, + cancel: cancel, + closeFunc: cf, + closed: make(chan struct{}), + } + + go c.closeOnContextDone() + return c +} + +func (c *contextCloser) Context() context.Context { + return c.ctx +} + +func (c *contextCloser) Children() *sync.WaitGroup { + return &c.children +} + +// Close is the external close function. it's a wrapper around internalClose +// that waits on Closed() +func (c *contextCloser) Close() error { + c.internalClose() + <-c.Closed() // wait until we're totally done. + return c.closeErr +} + +func (c *contextCloser) Closing() <-chan struct{} { + return c.Context().Done() +} + +func (c *contextCloser) Closed() <-chan struct{} { + return c.closed +} + +func (c *contextCloser) internalClose() { + go c.closeOnce.Do(c.closeLogic) +} + +// the _actual_ close process. +func (c *contextCloser) closeLogic() { + // this function should only be called once (hence the sync.Once). + // and it will panic at the bottom (on close(c.closed)) otherwise. + + c.cancel() // signal that we're shutting down (Closing) + c.closeErr = c.closeFunc() // actually run the close logic + c.children.Wait() // wait till all children are done. + close(c.closed) // signal that we're shut down (Closed) +} + +// if parent context is shut down before we call Close explicitly, +// we need to go through the Close motions anyway. Hence all the sync +// stuff all over the place... +func (c *contextCloser) closeOnContextDone() { + c.Children().Add(1) // we're a child goroutine, to be waited upon. + <-c.Context().Done() // wait until parent (context) is done. + c.internalClose() + c.Children().Done() +} diff --git a/util/key.go b/util/key.go index e7b5246c917..abcbf63291d 100644 --- a/util/key.go +++ b/util/key.go @@ -71,3 +71,12 @@ func IsValidHash(s string) bool { } return true } + +// XOR takes two byte slices, XORs them together, returns the resulting slice. +func XOR(a, b []byte) []byte { + c := make([]byte, len(a)) + for i := 0; i < len(a); i++ { + c[i] = a[i] ^ b[i] + } + return c +} diff --git a/util/log.go b/util/log.go index 6a66024de19..ae633de8374 100644 --- a/util/log.go +++ b/util/log.go @@ -13,8 +13,19 @@ func init() { var log = Logger("util") -// LogFormat is the format used for our logger. -var LogFormat = "%{color}%{time:2006-01-02 15:04:05.999999} %{shortfile} %{level}: %{color:reset}%{message}" +var ansiGray = "\033[0;37m" + +// LogFormats is a map of formats used for our logger, keyed by name. +var LogFormats = map[string]string{ + "default": "%{color}%{time:2006-01-02 15:04:05.999999} %{level} %{shortfile}: %{color:reset}%{message}", + "color": ansiGray + "%{time:15:04:05.999} %{color}%{level}: %{color:reset}%{message} " + ansiGray + "%{shortfile}%{color:reset}", +} + +// Logging environment variables +const ( + envLogging = "IPFS_LOGGING" + envLoggingFmt = "IPFS_LOGGING_FMT" +) // loggers is the set of loggers in the system var loggers = map[string]*logging.Logger{} @@ -26,13 +37,19 @@ func POut(format string, a ...interface{}) { // SetupLogging will initialize the logger backend and set the flags. func SetupLogging() { + + fmt := LogFormats[os.Getenv(envLoggingFmt)] + if fmt == "" { + fmt = LogFormats["default"] + } + backend := logging.NewLogBackend(os.Stderr, "", 0) logging.SetBackend(backend) - logging.SetFormatter(logging.MustStringFormatter(LogFormat)) + logging.SetFormatter(logging.MustStringFormatter(fmt)) lvl := logging.ERROR - if logenv := os.Getenv("IPFS_LOGGING"); logenv != "" { + if logenv := os.Getenv(envLogging); logenv != "" { var err error lvl, err = logging.LogLevel(logenv) if err != nil { diff --git a/util/util_test.go b/util/util_test.go index a85c492feeb..c2bb8f484e1 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -58,3 +58,30 @@ func TestByteChanReader(t *testing.T) { t.Fatal("Reader failed to stream correct bytes") } } + +func TestXOR(t *testing.T) { + cases := [][3][]byte{ + [3][]byte{ + []byte{0xFF, 0xFF, 0xFF}, + []byte{0xFF, 0xFF, 0xFF}, + []byte{0x00, 0x00, 0x00}, + }, + [3][]byte{ + []byte{0x00, 0xFF, 0x00}, + []byte{0xFF, 0xFF, 0xFF}, + []byte{0xFF, 0x00, 0xFF}, + }, + [3][]byte{ + []byte{0x55, 0x55, 0x55}, + []byte{0x55, 0xFF, 0xAA}, + []byte{0x00, 0xAA, 0xFF}, + }, + } + + for _, c := range cases { + r := XOR(c[0], c[1]) + if !bytes.Equal(r, c[2]) { + t.Error("XOR failed") + } + } +}