From dae5362dd298e2090ea7ec7423145ebc58f3f34e Mon Sep 17 00:00:00 2001
From: Sanjay Ghemawat
Date: Thu, 3 Aug 2023 16:29:14 -0700
Subject: [PATCH] Avoid using connections until they are healthy. (#498)

* Avoid using connections until they are healthy.
* Change Balancer to track the set of healthy connections.
* Added state machine to clientConnection.
* Create clientConnection as soon as a resolver returns an address.
* Do version handshake on connection before adding it to the balancer.
* Removed CallOptions.Balancer (call.Connection now has one balancer).
* Dropped unused Sharded balancer.
* Dropped some obsolete tests.
* Tweaked some tests to account for changed behavior.
* Added deadlines to some tests to make them behave better when things hang due to a bug.
* Undid earlier bad renaming of object to component.
* Split a large test into multiple tests.

We also disable the method call panic test in weavertest since it does not work. Explanation: if a remote weavelet panicked, its exit raced with weavertest cleanup. If the code reading from the remote weavelet detected the broken connection before weavertest got a chance to mark the test as done, we would print an error message and exit the test process. This interacted poorly with the weavertest/internal/generate test that intentionally triggers a panic in a remote component.

To elaborate on what is going wrong, if a subprocess in a weavertest panics, this code will error out:

```
weaver/weavertest/deployer.go
Line 374 in c893d9a
err := e.Serve(handler)
```

This leads to stopLocked being called, which cancels a context:

```
weaver/weavertest/deployer.go
Line 218 in c893d9a
d.ctxCancel()
```

This context is used to create all weavelets. When this context is cancelled, the pipes between all envelopes and weavelets shut down. The main weavelet detects this and self-terminates, causing the test to fail even though it should pass.

```
weaver/internal/weaver/remoteweavelet.go
Line 178 in c893d9a
return w.conn.Serve(w)
```
---
 godeps.txt                               |   1 -
 internal/net/call/balancer.go            | 114 ++---
 internal/net/call/call.go                | 616 +++++++++++++++--------
 internal/net/call/call_test.go           | 182 ++-----
 internal/net/call/options.go             |   5 -
 internal/weaver/remoteweavelet.go        |   2 +-
 internal/weaver/routing.go               |  58 ++-
 internal/weaver/routing_test.go          |  46 +-
 internal/weaver/stub.go                  |   2 -
 weavertest/internal/generate/app_test.go |  42 +-
 10 files changed, 592 insertions(+), 476 deletions(-)

diff --git a/godeps.txt b/godeps.txt
index 9d303a0d3..420006b73 100644
--- a/godeps.txt
+++ b/godeps.txt
@@ -402,7 +402,6 @@ github.com/ServiceWeaver/weaver/internal/net/call
 	go.opentelemetry.io/otel/trace
 	golang.org/x/exp/slog
 	io
-	math/rand
 	net
 	strings
 	sync
diff --git a/internal/net/call/balancer.go b/internal/net/call/balancer.go
index 8684fe44b..2c5c8f015 100644
--- a/internal/net/call/balancer.go
+++ b/internal/net/call/balancer.go
@@ -14,63 +14,55 @@ package call
-import (
-	"fmt"
-	"math/rand"
-)
+// ReplicaConnection is a connection to a single replica. A single Connection
+// may consist of many ReplicaConnections (typically one per replica).
+type ReplicaConnection interface {
+	// Address returns the name of the endpoint to which the ReplicaConnection
+	// is connected.
+	Address() string
+}
-// A Balancer picks the endpoint to which which an RPC client performs a call. A
-// Balancer should only be used by a single goroutine.
+// Balancer manages a set of ReplicaConnections and picks one of them per
+// call. 
A Balancer requires external synchronization (no concurrent calls +// should be made to the same Balancer). // // TODO(mwhittaker): Right now, balancers have no load information about // endpoints. In the short term, we can at least add information about the // number of pending requests for every endpoint. -// -// TODO(mwhittaker): Right now, we pass a balancer the set of all endpoints. We -// instead probably want to pass it only the endpoints for which we have a -// connection. This means we may have to form connections more eagerly. -// -// TODO(mwhittaker): We may want to guarantee that Update() is never called -// with an empty list of addresses. If we don't have addresses, then we don't -// need to do balancing. type Balancer interface { - // Update updates the current set of endpoints from which the Balancer can - // pick. Before Update is called for the first time, the set of endpoints - // is empty. - Update(endpoints []Endpoint) - - // Pick picks an endpoint. Pick is guaranteed to return an endpoint that - // was passed to the most recent call of Update. If there are no endpoints, - // then Pick returns an error that includes Unreachable. - Pick(CallOptions) (Endpoint, error) + // Add adds a ReplicaConnection to the set of connections. + Add(ReplicaConnection) + + // Remove removes a ReplicaConnection from the set of connections. + Remove(ReplicaConnection) + + // Pick picks a ReplicaConnection from the set of connections. + // Pick returns _,false if no connections are available. + Pick(CallOptions) (ReplicaConnection, bool) } // balancerFuncImpl is the implementation of the "functional" balancer // returned by BalancerFunc. type balancerFuncImpl struct { - endpoints []Endpoint - pick func([]Endpoint, CallOptions) (Endpoint, error) + connList + pick func([]ReplicaConnection, CallOptions) (ReplicaConnection, bool) } var _ Balancer = &balancerFuncImpl{} -// BalancerFunc returns a stateless, purely functional load balancer that uses -// the provided picking function. -func BalancerFunc(pick func([]Endpoint, CallOptions) (Endpoint, error)) Balancer { +// BalancerFunc returns a stateless, purely functional load balancer that calls +// pick to pick the connection to use. +func BalancerFunc(pick func([]ReplicaConnection, CallOptions) (ReplicaConnection, bool)) Balancer { return &balancerFuncImpl{pick: pick} } -func (bf *balancerFuncImpl) Update(endpoints []Endpoint) { - bf.endpoints = endpoints -} - -func (bf *balancerFuncImpl) Pick(opts CallOptions) (Endpoint, error) { - return bf.pick(bf.endpoints, opts) +func (bf *balancerFuncImpl) Pick(opts CallOptions) (ReplicaConnection, bool) { + return bf.pick(bf.list, opts) } type roundRobin struct { - endpoints []Endpoint - next int + connList + next int } var _ Balancer = &roundRobin{} @@ -80,37 +72,35 @@ func RoundRobin() *roundRobin { return &roundRobin{} } -func (rr *roundRobin) Update(endpoints []Endpoint) { - rr.endpoints = endpoints -} - -func (rr *roundRobin) Pick(CallOptions) (Endpoint, error) { - if len(rr.endpoints) == 0 { - return nil, fmt.Errorf("%w: no endpoints available", Unreachable) +func (rr *roundRobin) Pick(CallOptions) (ReplicaConnection, bool) { + if len(rr.list) == 0 { + return nil, false } - if rr.next >= len(rr.endpoints) { + if rr.next >= len(rr.list) { rr.next = 0 } - endpoint := rr.endpoints[rr.next] + c := rr.list[rr.next] rr.next += 1 - return endpoint, nil + return c, true } -// Sharded returns a new sharded balancer. 
-// -// Given a list of n endpoints e1, ..., en, for a request with shard key k, a -// sharded balancer will pick endpoint ei where i = k mod n. If no shard key is -// provided, an endpoint is picked at random. -func Sharded() Balancer { - return BalancerFunc(func(endpoints []Endpoint, opts CallOptions) (Endpoint, error) { - n := len(endpoints) - if n == 0 { - return nil, fmt.Errorf("%w: no endpoints available", Unreachable) - } - if opts.ShardKey == 0 { - // There is no ShardKey. Pick an endpoint at random. - return endpoints[rand.Intn(n)], nil +// connList is a helper type used by balancers to maintain set of connections. +type connList struct { + list []ReplicaConnection +} + +func (cl *connList) Add(c ReplicaConnection) { + cl.list = append(cl.list, c) +} + +func (cl *connList) Remove(c ReplicaConnection) { + for i, elem := range cl.list { + if elem != c { + continue } - return endpoints[opts.ShardKey%uint64(n)], nil - }) + // Replace removed entry with last entry. + cl.list[i] = cl.list[len(cl.list)-1] + cl.list = cl.list[:len(cl.list)-1] + return + } } diff --git a/internal/net/call/call.go b/internal/net/call/call.go index fe51b4511..bb1891c42 100644 --- a/internal/net/call/call.go +++ b/internal/net/call/call.go @@ -38,19 +38,19 @@ package call // // # Client operation // -// A client creates connections to one or more servers and, for every -// connection, starts a background readResponses() goroutine that reads -// messages from the connection. +// For each newly discovered server, the client starts a manage() goroutine +// that connects to server, and then reads messages from the connection. If the +// network connection breaks, manage() reconnects (after a retry delay). // // When the client wants to send an RPC, it selects one of its server -// connections to use, creates a call component, assigns it a new request-id, and -// registers the components in a map in the connection. It then sends a request -// message over the connection and waits for the call component to be marked as +// connections to use, creates a call object, assigns it a new request-id, and +// registers the object in a map in the connection. It then sends a request +// message over the connection and waits for the call object to be marked as // done. // -// When the response arrives, it is picked up by readResponses(). -// readResponses() finds the call component corresponding to the -// request-id in the response, and marks the call component as done which +// When the response arrives, it is picked up by readAndProcessMessage(). +// readAndProcessMessage() finds the call object corresponding to the +// request-id in the response, and marks the call object as done which // wakes up goroutine that initiated the RPC. // // If a client is constructed with a non-constant resolver, the client also @@ -63,9 +63,8 @@ package call // requests on a draining connection are allowed to finish. As soon as a // draining connection has no active calls, the connection closes itself. If // the resolver later returns a new set of endpoints that includes a draining -// connection that hasn't closed itself, the connection is transitioned out of -// the draining phase and is once again allowed to process new RPCs. - +// connection that hasn't closed itself, the draining connection is turned +// back into a normal connection. import ( "bufio" "context" @@ -88,21 +87,8 @@ import ( const ( // Size of the header included in each message. 
msgHeaderSize = 16 + 8 + traceHeaderLen // handler_key + deadline + trace_context - - // maxReconnectTries is the maximum number of times a reconnecting - // connection will try and create a connection before erroring out. - maxReconnectTries = 3 ) -// TODO: -// - Load balancer -// - API to allow changes to set -// - health-checks -// - track subset that is healthy -// - track load info -// - data structure for efficient picking (randomize? weighted?) -// - pick on call (error if none available) - // Connection allows a client to send RPCs. type Connection interface { // Call makes an RPC over a Connection. @@ -129,33 +115,79 @@ type reconnectingConnection struct { // mu guards the following fields and some of the fields in the // clientConnections inside connections and draining. - mu sync.Mutex - endpoints []Endpoint - connections map[string]*clientConnection // keys are endpoint addresses - draining map[string]*clientConnection // keys are endpoint addresses - closed bool + mu sync.Mutex + conns map[string]*clientConnection + closed bool resolver Resolver cancelResolver func() // cancels the watchResolver goroutine resolverDone sync.WaitGroup // used to wait for watchResolver to finish } +// connState is the state of a clientConnection (connection to a particular +// server replica). missing is a special state used for unknown servers. A +// typical sequence of transitions is: +// +// missing -> disconnected -> checking -> idle <-> active -> draining -> missing +// +// The events that can cause state transition are: +// +// - register: server has shown up in resolver results +// - unregister: server has dropped from resolver results +// - connected: a connection has been successfully made +// - checked: connection has been successfully checked +// - callstart: call starts on connection +// - lastdone: last active call on connection has ended +// - fail: some protocol error is detected on the connection +// - close: reconnectingConnection is being closed +// +// Each event has a corresponding clientConnection method below. See +// those methods for the corresponding state transitions. +type connState int8 + +const ( + missing connState = iota + disconnected // cannot be used for calls + checking // checking new network connection + idle // can be used for calls, no calls in-flight + active // can be used for calls, some calls in-flight + draining // some calls in-flight, no new calls should be added +) + +var connStateNames = []string{ + "missing", + "disconnected", + "checking", + "idle", + "active", + "draining", +} + +func (s connState) String() string { return connStateNames[s] } + // clientConnection manages one network connection on the client-side. type clientConnection struct { - logger *slog.Logger - endpoint Endpoint - c net.Conn - cbuf *bufio.Reader // Buffered reader wrapped around c - wlock sync.Mutex // Guards writes to c - mu *sync.Mutex // Same as reconnectingConnection.mu - draining bool // is this clientConnection draining? - ended bool // has this clientConnection ended? + // Immutable after construction. + rc *reconnectingConnection // owner + canceler func() // Used to cancel goroutine handling connection + logger *slog.Logger + endpoint Endpoint + + wlock sync.Mutex // Guards writes to c + + // Guarded by rc.mu + state connState // current connection state loggedShutdown bool // Have we logged a shutdown error? + inBalancer bool // Is c registered with the balancer? 
+ c net.Conn // Active network connection, or nil + cbuf *bufio.Reader // Buffered reader wrapped around c version version // Version number to use for connection calls map[uint64]*call // In-progress calls lastID uint64 // Last assigned request ID for a call } +var _ ReplicaConnection = &clientConnection{} + // call holds the state for an active call at the client. type call struct { id uint64 @@ -283,9 +315,7 @@ func Connect(ctx context.Context, resolver Resolver, opts ClientOptions) (Connec // Construct the connection. conn := reconnectingConnection{ opts: opts.withDefaults(), - endpoints: []Endpoint{}, - connections: map[string]*clientConnection{}, - draining: map[string]*clientConnection{}, + conns: map[string]*clientConnection{}, resolver: resolver, cancelResolver: func() {}, } @@ -304,7 +334,7 @@ func Connect(ctx context.Context, resolver Resolver, opts ClientOptions) (Connec if !resolver.IsConstant() && version == nil { return nil, errors.New("non-constant resolver returned a nil version") } - if err := conn.updateEndpoints(endpoints); err != nil { + if err := conn.updateEndpoints(ctx, endpoints); err != nil { return nil, err } @@ -330,11 +360,8 @@ func (rc *reconnectingConnection) Close() { return } rc.closed = true - for _, conn := range rc.connections { - conn.endCalls(fmt.Errorf("%w: %s", CommunicationError, "connection closed")) - } - for _, conn := range rc.draining { - conn.endCalls(fmt.Errorf("%w: %s", CommunicationError, "connection closed")) + for _, c := range rc.conns { + c.close() } } closeWithLock() @@ -379,12 +406,11 @@ func (rc *reconnectingConnection) Call(ctx context.Context, h MethodKey, arg []b // connection, we may want to try it again on a different connection. We // may also want to detect that certain connections are bad and avoid them // outright. - conn, err := rc.startCall(ctx, rpc, opts) + conn, nc, err := rc.startCall(ctx, rpc, opts) if err != nil { return nil, err } - - if err := writeMessage(conn.c, &conn.wlock, requestMessage, rpc.id, hdr[:], arg, rc.opts.WriteFlattenLimit); err != nil { + if err := writeMessage(nc, &conn.wlock, requestMessage, rpc.id, hdr[:], arg, rc.opts.WriteFlattenLimit); err != nil { conn.shutdown("client send request", err) conn.endCall(rpc) return nil, fmt.Errorf("%w: %s", CommunicationError, err) @@ -409,7 +435,7 @@ func (rc *reconnectingConnection) Call(ctx context.Context, h MethodKey, arg []b if !haveDeadline || time.Now().Before(deadline) { // Early cancellation. Tell server about it. - if err := writeMessage(conn.c, &conn.wlock, cancelMessage, rpc.id, nil, nil, rc.opts.WriteFlattenLimit); err != nil { + if err := writeMessage(nc, &conn.wlock, cancelMessage, rpc.id, nil, nil, rc.opts.WriteFlattenLimit); err != nil { conn.shutdown("client send cancel", err) } } @@ -443,7 +469,7 @@ func (rc *reconnectingConnection) watchResolver(ctx context.Context, version *Ve // Resolver wishes to be called again after an appropriate delay. continue } - if err := rc.updateEndpoints(endpoints); err != nil { + if err := rc.updateEndpoints(ctx, endpoints); err != nil { logError(rc.opts.Logger, "watchResolver", err) } version = newVersion @@ -454,7 +480,7 @@ func (rc *reconnectingConnection) watchResolver(ctx context.Context, version *Ve // updateEndpoints updates the set of endpoints. Existing connections are // retained, and stale connections are closed. // REQUIRES: rc.mu is not held. 
-func (rc *reconnectingConnection) updateEndpoints(endpoints []Endpoint) error { +func (rc *reconnectingConnection) updateEndpoints(ctx context.Context, endpoints []Endpoint) error { rc.mu.Lock() defer rc.mu.Unlock() @@ -462,190 +488,274 @@ func (rc *reconnectingConnection) updateEndpoints(endpoints []Endpoint) error { return fmt.Errorf("updateEndpoints on closed Connection") } - // Remove fully drained connections since they have been closed already and - // cannot be reused. - rc.removeDrainedConnections() - - // Retain existing connections. - connections := make(map[string]*clientConnection, len(endpoints)) + // Make new endpoints. + keep := make(map[string]struct{}, len(endpoints)) for _, endpoint := range endpoints { addr := endpoint.Address() - if conn, ok := rc.connections[addr]; ok { - connections[addr] = conn - delete(rc.connections, addr) - } else if conn, ok := rc.draining[addr]; ok { - conn.draining = false - connections[addr] = conn - delete(rc.draining, addr) - } else { - // If we don't have an existing connection, it will be created - // on-demand when Call is invoked. We don't have to insert anything - // into rc.connections. + keep[addr] = struct{}{} + if _, ok := rc.conns[addr]; !ok { + // New endpoint, create connection and manage it. + ctx, cancel := context.WithCancel(ctx) + c := &clientConnection{ + rc: rc, + canceler: cancel, + logger: rc.opts.Logger, + endpoint: endpoint, + calls: map[uint64]*call{}, + lastID: 0, + } + rc.conns[addr] = c + c.register() + go c.manage(ctx) } } - // Update our state. - rc.endpoints = endpoints - for addr, conn := range rc.connections { - conn.draining = true - rc.draining[addr] = conn + // Drop old endpoints. + for addr, c := range rc.conns { + if _, ok := keep[addr]; ok { + // Still live, so keep it. + continue + } + c.unregister() } - rc.connections = connections - rc.opts.Balancer.Update(endpoints) - - // Close draining connections that don't have any pending requests. If a - // draining connection does have pending requests, then the connection will - // close itself when it finishes processing all of its requests. - rc.removeDrainedConnections() - - // TODO(mwhittaker): Close draining connections after a delay? return nil } -// removeDrainedConnections closes and removes any fully drained connections -// from rc.draining. -// -// REQUIRES: rc.mu is held. -func (rc *reconnectingConnection) removeDrainedConnections() { - for addr, conn := range rc.draining { - conn.endIfDrained() - if conn.ended { - delete(rc.draining, addr) +// startCall registers a new in-progress call. +// REQUIRES: rc.mu is not held. +func (rc *reconnectingConnection) startCall(ctx context.Context, rpc *call, opts CallOptions) (*clientConnection, net.Conn, error) { + for r := retry.Begin(); r.Continue(ctx); { + rc.mu.Lock() + if rc.closed { + rc.mu.Unlock() + return nil, nil, fmt.Errorf("Call on closed Connection") + } + + replica, ok := rc.opts.Balancer.Pick(opts) + if !ok { + rc.mu.Unlock() + continue } + + c, ok := replica.(*clientConnection) + if !ok { + rc.mu.Unlock() + return nil, nil, fmt.Errorf("internal error: wrong connection type %#v returned by load balancer", replica) + } + + c.lastID++ + rpc.id = c.lastID + c.calls[rpc.id] = rpc + c.callstart() + nc := c.c + rc.mu.Unlock() + + return c, nc, nil } + + return nil, nil, ctx.Err() } -// startCall registers a new in-progress call. -// REQUIRES: rc.mu is not held. 
-func (rc *reconnectingConnection) startCall(ctx context.Context, rpc *call, opts CallOptions) (*clientConnection, error) { - rc.mu.Lock() - defer rc.mu.Unlock() +func (c *clientConnection) Address() string { + return c.endpoint.Address() +} - if rc.closed { - return nil, fmt.Errorf("Call on closed Connection") +// State transition actions: all of these are called with rc.mu held. + +func (c *clientConnection) register() { + switch c.state { + case missing: + c.setState(disconnected) + case draining: + // We were attempting to get rid of the old connection, but it + // seems like the server-side problem was transient, so we + // resurrect the draining connection into a non-draining state. + // + // New state is active instead of idle since state==draining + // implies there is at least one call in-flight. + c.setState(active) } +} - if len(rc.endpoints) == 0 { - return nil, fmt.Errorf("%w: no endpoints available", Unreachable) +func (c *clientConnection) unregister() { + switch c.state { + case disconnected, checking, idle: + c.setState(missing) + case active: + c.setState(draining) } +} - // Note that it is important to hold rc.mu when calling Pick(), and it's - // important that we index into rc.connections with addr while still - // holding rc.mu. Otherwise, a Pick() call could operate on a stale set of - // endpoints and return an endpoint that does not exist in rc.connections. - var balancer = rc.opts.Balancer - if opts.Balancer != nil { - balancer = opts.Balancer - balancer.Update(rc.endpoints) +func (c *clientConnection) connected() { + switch c.state { + case disconnected: + c.setState(checking) } +} - // TODO(mwhittaker): Think about the other places where we can perform - // automatic retries. We need to be careful about non-idempotent - // operations. - var connectErr error - for i := 0; i < maxReconnectTries; i++ { - endpoint, err := balancer.Pick(opts) - if err != nil { - return nil, err - } - addr := endpoint.Address() +func (c *clientConnection) checked() { + switch c.state { + case checking: + c.setState(idle) + } +} - if conn, ok := rc.connections[addr]; !ok || conn.ended { - c, err := rc.reconnect(ctx, endpoint) - if err != nil { - connectErr = err - continue - } - rc.connections[addr] = c - } +func (c *clientConnection) callstart() { + switch c.state { + case idle: + c.setState(active) + } +} - c := rc.connections[addr] - c.lastID++ - rpc.id = c.lastID - c.calls[rpc.id] = rpc - return c, nil +func (c *clientConnection) lastdone() { + switch c.state { + case active: + c.setState(idle) + case draining: + c.setState(missing) } - return nil, connectErr } -// reconnect establishes (or re-establishes) the network connection to the server. -// REQUIRES: rc.mu is held. -func (rc *reconnectingConnection) reconnect(ctx context.Context, endpoint Endpoint) (*clientConnection, error) { - nc, err := endpoint.Dial(ctx) - if err != nil { - return nil, fmt.Errorf("%w: %s", CommunicationError, err) +func (c *clientConnection) fail(details string, err error) { + if !c.loggedShutdown { + c.loggedShutdown = true + logError(c.logger, details, err) + } + switch c.state { + case checking, idle, active: + c.setState(disconnected) + case draining: + c.setState(missing) + } +} + +func (c *clientConnection) close() { + c.setState(missing) +} + +// checkInvariants verifies clientConnection invariants. 
+func (c *clientConnection) checkInvariants() { + s := c.state + + // connection in reconnectingConnection.conns iff state not in {missing} + if _, ok := c.rc.conns[c.endpoint.Address()]; ok != (s != missing) { + panic(fmt.Sprintf("%v connection: wrong connection table presence %v", s, ok)) + } + + // has net.Conn iff state in {checking, idle, active, draining} + if (c.c != nil) != (s == checking || s == idle || s == active || s == draining) { + panic(fmt.Sprintf("%v connection: wrong net.Conn %v", s, c.c)) + } + + // connection is in the balancer iff state in {idle, active} + if c.inBalancer != (s == idle || s == active) { + panic(fmt.Sprintf("%v connection: wrong balancer presence %v", s, c.inBalancer)) + } + + // len(calls) > 0 iff state in {active, draining} + if (len(c.calls) != 0) != (s == active || s == draining) { + panic(fmt.Sprintf("%v connection: wrong number of calls %d", s, len(c.calls))) } - conn := &clientConnection{ - logger: rc.opts.Logger, - endpoint: endpoint, - c: nc, - cbuf: bufio.NewReader(nc), - mu: &rc.mu, - version: initialVersion, // Updated when we hear from server - calls: map[uint64]*call{}, - lastID: 0, +} + +// setState transitions to state s and updates any related state. +func (c *clientConnection) setState(s connState) { + // idle<-> active transitions may happen a lot, so short-circuit them + // by avoiding logging and full invariant maintenance. + if c.state == active && s == idle { + c.state = idle + if len(c.calls) != 0 { + panic(fmt.Sprintf("%v connection: wrong number of calls %d", s, len(c.calls))) + } + return + } else if c.state == idle && s == active { + c.state = active + if len(c.calls) == 0 { + panic(fmt.Sprintf("%v connection: wrong number of calls %d", s, len(c.calls))) + } + return + } + + c.logger.Info(fmt.Sprintf("connection %p", c), "addr", c.endpoint.Address(), "from", c.state, "to", s, "b", c.inBalancer) + c.state = s + + // Fix membership in rc.conns. + if s == missing { + delete(c.rc.conns, c.endpoint.Address()) + if c.canceler != nil { + c.canceler() // Forces retry loop to end early + c.canceler = nil + } + } // else: caller is responsible for adding c to rc.conns + + // Fix net.Conn presence. + if s == missing || s == disconnected { + if c.c != nil { + c.c.Close() + c.c = nil + c.cbuf = nil + } + } // else: caller is responsible for setting c.c and c.cbuf + + // Fix balancer membership. + if s == idle || s == active { + if !c.inBalancer { + c.rc.opts.Balancer.Add(c) + c.inBalancer = true + } + } else { + if c.inBalancer { + c.rc.opts.Balancer.Remove(c) + c.inBalancer = false + } } - if err := writeVersion(conn.c, &conn.wlock); err != nil { - return nil, fmt.Errorf("%w: client send version: %s", CommunicationError, err) + + // Fix in-flight calls. + if s == active || s == draining { + // Keep calls live + } else { + // XXX Pass in detail and/or error + c.endCalls(fmt.Errorf("%w: %v", CommunicationError, s)) } - go conn.readResponses() - return conn, nil + + c.checkInvariants() } func (c *clientConnection) endCall(rpc *call) { - c.mu.Lock() - defer c.mu.Unlock() + c.rc.mu.Lock() + defer c.rc.mu.Unlock() delete(c.calls, rpc.id) - c.endIfDrained() + if len(c.calls) == 0 { + c.lastdone() + } } func (c *clientConnection) findAndEndCall(id uint64) *call { - c.mu.Lock() - defer c.mu.Unlock() + c.rc.mu.Lock() + defer c.rc.mu.Unlock() rpc := c.calls[id] if rpc != nil { delete(c.calls, id) - c.endIfDrained() + if len(c.calls) == 0 { + c.lastdone() + } } return rpc } -// endIfDrained closes c if it is a fully drained connection. 
-// -// REQUIRES: c.mu is held. -func (c *clientConnection) endIfDrained() { - // Note that endIfDrained closes c, but it doesn't remove c from - // reconnectingConnection.draining. rc.updateEndpoints will remove drained - // connections from rc.draining. This approach leaves some drained - // connections around, but it simplifies the code. Specifically, a - // reconnectingConnection may modify a child clientConnection, but a - // clientConnection never modifies its parent reconnectingConnection. - if c.draining && len(c.calls) == 0 { - c.endCalls(fmt.Errorf("connection drained")) - } -} - // shutdown processes an error detected while operating on a connection. // It closes the network connection and cancels all requests in progress on the connection. // REQUIRES: c.mu is not held. func (c *clientConnection) shutdown(details string, err error) { - c.mu.Lock() - defer c.mu.Unlock() - if !c.loggedShutdown { - c.loggedShutdown = true - logError(c.logger, "shutdown: "+details, err) - } - - // Cancel all in-progress calls. - c.endCalls(fmt.Errorf("%w: %s: %s", CommunicationError, details, err)) + c.rc.mu.Lock() + defer c.rc.mu.Unlock() + c.fail(details, err) } // endCalls closes the network connection and ends any in-progress calls. // REQUIRES: c.mu is held. func (c *clientConnection) endCalls(err error) { - c.c.Close() - c.ended = true for id, active := range c.calls { active.err = err atomic.StoreUint32(&active.done, 1) @@ -654,46 +764,114 @@ func (c *clientConnection) endCalls(err error) { } } -// readResponses runs on the client side reading messages sent over a connection by the server. -func (c *clientConnection) readResponses() { - for { - mt, id, msg, err := readMessage(c.cbuf) - if err != nil { - c.shutdown("client read", err) - return +// manage handles a live clientConnection until it becomes missing. +func (c *clientConnection) manage(ctx context.Context) { + for r := retry.Begin(); r.Continue(ctx); { + progress := c.connectOnce(ctx) + if progress { + r.Reset() } + } +} - switch mt { - case versionMessage: - v, err := getVersion(id, msg) - if err != nil { - c.shutdown("client read", err) - return - } - c.mu.Lock() - c.version = v - c.mu.Unlock() - case responseMessage, responseError: - rpc := c.findAndEndCall(id) - if rpc == nil { - continue // May have been canceled - } - if mt == responseError { - if err, ok := decodeError(msg); ok { - rpc.err = err - } else { - rpc.err = fmt.Errorf("%w: could not decode error", CommunicationError) - } +// connectOnce dials once to the endpoint and manages the resulting connection. +// It returns true if some communication happened successfully over the connection. +func (c *clientConnection) connectOnce(ctx context.Context) bool { + // Dial the connection. + nc, err := c.endpoint.Dial(ctx) + if err != nil { + logError(c.logger, "dial", err) + return false + } + defer nc.Close() + + c.rc.mu.Lock() + defer c.rc.mu.Unlock() // Also temporarily unlocked below + c.c = nc + c.cbuf = bufio.NewReader(nc) + c.loggedShutdown = false + c.connected() + + // Handshake to get the peer version and verify that it is live. + if err := c.exchangeVersions(); err != nil { + c.fail("handshake", err) + return false + } + c.checked() + + for c.state == idle || c.state == active || c.state == draining { + if err := c.readAndProcessMessage(); err != nil { + c.fail("client read", err) + } + } + return true +} + +// exchangeVersions sends client version to server and waits for the server version. 
+func (c *clientConnection) exchangeVersions() error { + nc, buf := c.c, c.cbuf + + // Do not hold mutex while reading from the network. + c.rc.mu.Unlock() + defer c.rc.mu.Lock() + + if err := writeVersion(nc, &c.wlock); err != nil { + return err + } + mt, id, msg, err := readMessage(buf) + if err != nil { + return err + } + if mt != versionMessage { + return fmt.Errorf("wrong message type %d, expecting %d", mt, versionMessage) + } + v, err := getVersion(id, msg) + if err != nil { + return err + } + c.version = v + return nil +} + +// readAndProcessMessage reads and handles one message sent from the server. +func (c *clientConnection) readAndProcessMessage() error { + buf := c.cbuf + + // Do not hold mutex while reading from the network. + c.rc.mu.Unlock() + defer c.rc.mu.Lock() + + mt, id, msg, err := readMessage(buf) + if err != nil { + return err + } + switch mt { + case versionMessage: + _, err := getVersion(id, msg) + if err != nil { + return err + } + // Ignore versions sent after initial hand-shake + case responseMessage, responseError: + rpc := c.findAndEndCall(id) + if rpc == nil { + return nil // May have been canceled + } + if mt == responseError { + if err, ok := decodeError(msg); ok { + rpc.err = err } else { - rpc.response = msg + rpc.err = fmt.Errorf("%w: could not decode error", CommunicationError) } - atomic.StoreUint32(&rpc.done, 1) - close(rpc.doneSignal) - default: - c.shutdown("client read", fmt.Errorf("invalid response %d", mt)) - return + } else { + rpc.response = msg } + atomic.StoreUint32(&rpc.done, 1) + close(rpc.doneSignal) + default: + return fmt.Errorf("invalid response %d", mt) } + return nil } // readRequests runs on the server side reading messages sent over a connection by the client. diff --git a/internal/net/call/call_test.go b/internal/net/call/call_test.go index 07c377dbc..d942c778b 100644 --- a/internal/net/call/call_test.go +++ b/internal/net/call/call_test.go @@ -387,16 +387,14 @@ func (h hangingConn) Write(b []byte) (n int, err error) { // hangingEndpoint returns a hangingConn. type hangingEndpoint struct { - addr string + call.Endpoint } var _ call.Endpoint = hangingEndpoint{} -func (h hangingEndpoint) Address() string { return h.addr } - func (h hangingEndpoint) Dial(ctx context.Context) (net.Conn, error) { // Make real connection and wrap it inside a hangingConn. 
- c, err := call.NetEndpoint{"tcp", h.addr}.Dial(ctx) + c, err := h.Endpoint.Dial(ctx) if err != nil { return nil, err } @@ -494,9 +492,9 @@ func checkQuickCancel(ctx context.Context, t *testing.T, c call.Connection) erro return err } -func testCall(t *testing.T, client call.Connection) { +func testCall(ctx context.Context, t *testing.T, client call.Connection) { const arg = "hello" - result, err := client.Call(context.Background(), echoKey, []byte(arg), call.CallOptions{}) + result, err := client.Call(ctx, echoKey, []byte(arg), call.CallOptions{}) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -621,7 +619,7 @@ func TestSingleTCPServer(t *testing.T) { name string f func(*testing.T, call.Connection) }{ - {"TestCall", testCall}, + {"TestCall", func(t *testing.T, c call.Connection) { testCall(context.Background(), t, c) }}, {"TestConcurrentCalls", testConcurrentCalls}, {"TestError", testError}, {"TestDeadlineHandling", testDeadlineHandling}, @@ -640,11 +638,11 @@ func TestSingleTCPServer(t *testing.T) { for resolverName, maker := range resolverMakers { for _, protocol := range protocols { client := getClientConn(t, protocol, endpoints[protocol], maker) - defer client.Close() for _, subtest := range subtests { name := fmt.Sprintf("Shared/%s/%s/%s", resolverName, protocol, subtest.name) t.Run(name, func(t *testing.T) { subtest.f(t, client) }) } + client.Close() } } @@ -654,8 +652,8 @@ func TestSingleTCPServer(t *testing.T) { for _, subtest := range subtests { name := fmt.Sprintf("Fresh/%s/%s/%s", resolverName, protocol, subtest.name) client := getClientConn(t, protocol, endpoints[protocol], maker) - defer client.Close() t.Run(name, func(t *testing.T) { subtest.f(t, client) }) + client.Close() } } } @@ -697,13 +695,21 @@ func TestMultipleEndpoints(t *testing.T) { } defer client.Close() - for i := 0; i < 2*n; i++ { + // Run a bunch of calls and check that they spread out over the replicas. + count := map[string]int{} + const attempts = 100 + for i := 0; i < attempts; i++ { result, err := client.Call(ctx, whoKey, []byte{}, call.CallOptions{}) if err != nil { t.Fatalf("unexpected error: %v", err) } - if got, want := string(result), strconv.Itoa(i%n); got != want { - t.Fatalf("bad result: got %q, want %q", got, want) + count[string(result)]++ + } + want := attempts / n + for _, k := range []string{"0", "1", "2"} { + got := count[k] + if got < want/2 || got > want*2 { + t.Errorf("replica %s got %d, expecting ~%d", k, got, want) } } }) @@ -735,102 +741,6 @@ func TestChangingEndpoints(t *testing.T) { } } -// TestShardedBalancer tests that requests are routed correctly using a sharded -// load balancer. -func TestShardedBalancer(t *testing.T) { - ctx := context.Background() - s1, s2, s3 := server(t, "1"), server(t, "2"), server(t, "3") - resolver := call.NewConstantResolver(s1, s2, s3) - opts := call.ClientOptions{ - Balancer: call.Sharded(), - Logger: logger(t), - } - client, err := call.Connect(ctx, resolver, opts) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - defer client.Close() - - // Route with a key. - for i := 1; i < 10; i++ { // Skip 0, which indicates no key. - key := uint64(i) - result, err := client.Call(ctx, whoKey, []byte{}, call.CallOptions{ShardKey: key}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - wants := []string{"1", "2", "3"} - want := wants[i%3] - if got := string(result); got != want { - t.Fatalf("bad result: got %q, want %q", got, want) - } - } - - // Route without a key. 
- for i := 0; i < 10; i++ { - result, err := client.Call(ctx, whoKey, []byte{}, call.CallOptions{}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if got := string(result); got != "1" && got != "2" && got != "3" { - t.Fatalf("bad result: got %q, want %q, %q, or %q", got, "1", "2", "3") - } - } -} - -// TestCallOptionsBalancer tests that requests are routed correctly using a -// per-call load balancer. -func TestCallOptionsBalancer(t *testing.T) { - // Test plan: Create three servers named 1, 2, and 3. Create a call.Client - // to these servers using a sharded balancer. Invoke the who method, - // checking that the request with key i is routed to server i % 3. - ctx := context.Background() - s1, s2, s3 := server(t, "1"), server(t, "2"), server(t, "3") - resolver := call.NewConstantResolver(s1, s2, s3) - opts := call.ClientOptions{Logger: logger(t)} - client, err := call.Connect(ctx, resolver, opts) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - defer client.Close() - - // Route using per-call balancer. - for _, test := range []struct { - e call.Endpoint - want string - }{ - {s1, "1"}, - {s2, "2"}, - {s3, "3"}, - } { - b := call.BalancerFunc(func([]call.Endpoint, call.CallOptions) (call.Endpoint, error) { - return test.e, nil - }) - for i := 0; i < 10; i++ { - result, err := client.Call(ctx, whoKey, []byte{}, call.CallOptions{Balancer: b}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if got := string(result); got != test.want { - t.Fatalf("bad result: got %q, want %q", got, test.want) - } - } - } - - // Route with the default balancer. - for i := 0; i < 10; i++ { - result, err := client.Call(ctx, whoKey, []byte{}, call.CallOptions{}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if got := string(result); got != "1" && got != "2" && got != "3" { - t.Fatalf("bad result: got %q, want %q, %q, or %q", got, "1", "2", "3") - } - } -} - // TestNoEndpointsConstant tests that it is an error to call Connect with a // constant resolver that returns no endpoints. func TestNoEndpointsConstant(t *testing.T) { @@ -861,26 +771,29 @@ func TestNoEndpointsNonConstant(t *testing.T) { } // Making a call without any endpoints is an error though. - _, err = client.Call(ctx, echoKey, []byte{}, call.CallOptions{}) - if err == nil { - t.Fatal("unexpected success when expecting error") - } - if got, want := err, call.Unreachable; !errors.Is(got, want) { + sub, cancel := context.WithTimeout(ctx, shortDelay) + _, err = client.Call(sub, echoKey, []byte{}, call.CallOptions{}) + cancel() + if got, want := err, context.DeadlineExceeded; !errors.Is(got, want) { t.Fatalf("bad error: got %v, want %v", got, want) } // Add an endpoint and let the update propagate. resolver.Endpoints(server(t, "server")) waitUntil(t, func() bool { - _, err = client.Call(ctx, echoKey, []byte{}, call.CallOptions{}) + sub, cancel := context.WithTimeout(ctx, shortDelay) + defer cancel() + _, err = client.Call(sub, echoKey, []byte{}, call.CallOptions{}) return err == nil }) // Remove the endpoint. 
resolver.Endpoints() waitUntil(t, func() bool { - _, err = client.Call(ctx, echoKey, []byte{}, call.CallOptions{}) - return err != nil && errors.Is(err, call.Unreachable) + sub, cancel := context.WithTimeout(ctx, shortDelay) + defer cancel() + _, err = client.Call(sub, echoKey, []byte{}, call.CallOptions{}) + return errors.Is(err, context.DeadlineExceeded) }) } @@ -1165,21 +1078,12 @@ func TestRefreshDraining(t *testing.T) { if _, err := client.Call(ctx, sleepKey, []byte((2 * delaySlop).String()), call.CallOptions{}); err != nil { t.Fatalf("unexpected error: %v", err) } - - // If m1 were draining, then it would be closed because it has no active - // connections. The connection is no longer draining though, so it should - // still be open. - if m.Closed() { - t.Fatalf("connection closed") - } } // TestUnhealthyChannels attempts to use a resolver that returns a mix of // healthy and unhealthy endpoints and verifies that calls succeed by being // routed to the healthy backend. func TestInitiallyUnhealthyEndpoint(t *testing.T) { - t.Skip("TODO(sanjay): Implement unhealthy channel avoidance") - // To fix this, we will avoid connections on which an initial // health-check has not completed. There are some options here: // @@ -1204,18 +1108,24 @@ func TestInitiallyUnhealthyEndpoint(t *testing.T) { // Start a good server and a bad server. ct := startTest(t) - bad, good := ct.startTCPServer(), ct.startTCPServer() - resolver := call.NewConstantResolver(hangingEndpoint{bad}, call.TCP(good)) + badAddr, goodAddr := ct.startTCPServer(), ct.startTCPServer() + bad := hangingEndpoint{call.TCP(badAddr)} + good := call.TCP(goodAddr) + resolver := call.NewConstantResolver(bad, good) // Connect via a balancer whose picking decision we control. - var target atomic.Int32 - balancer := call.BalancerFunc(func(endpoints []call.Endpoint, options call.CallOptions) (call.Endpoint, error) { - i := int(target.Load()) - fmt.Fprintf(os.Stderr, "%d from %v\n", i, endpoints) - if i < 0 || i >= len(endpoints) { - return nil, fmt.Errorf("%w: no endpoints available", call.Unreachable) + var useGood atomic.Bool + balancer := call.BalancerFunc(func(conns []call.ReplicaConnection, options call.CallOptions) (call.ReplicaConnection, bool) { + want := bad.Address() + if useGood.Load() { + want = good.Address() + } + for _, c := range conns { + if c.Address() == want { + return c, true + } } - return endpoints[i], nil + return nil, false }) client, err := call.Connect(ct.ctx, resolver, call.ClientOptions{ Balancer: balancer, @@ -1228,11 +1138,11 @@ func TestInitiallyUnhealthyEndpoint(t *testing.T) { // Switch to using the good endpoint after a delay. ct.fork(func() { time.Sleep(shortDelay) - target.Store(1) + useGood.Store(true) }) start := time.Now() - testCall(t, client) + testCall(ct.ctx, t, client) elapsed := time.Since(start) if elapsed < shortDelay { t.Fatalf("call completed too soon: after %v, expecting at least %v", elapsed, shortDelay) diff --git a/internal/net/call/options.go b/internal/net/call/options.go index 4fe0646fb..623f11787 100644 --- a/internal/net/call/options.go +++ b/internal/net/call/options.go @@ -65,11 +65,6 @@ type CallOptions struct { // TODO(mwhittaker): Figure out a way to have 0 be a valid shard key. Could // change to *uint64 for example. ShardKey uint64 - - // Balancer, if not nil, is the Balancer to use for a call, instead of the - // Balancer that the client was constructed with (provided in - // ClientOptions). 
- Balancer Balancer } // withDefaults returns a copy of the ClientOptions with zero values replaced diff --git a/internal/weaver/remoteweavelet.go b/internal/weaver/remoteweavelet.go index 75040508d..f158b8509 100644 --- a/internal/weaver/remoteweavelet.go +++ b/internal/weaver/remoteweavelet.go @@ -340,6 +340,7 @@ func (w *RemoteWeavelet) makeStub(reg *codegen.Registration, resolver *routingRe // Create the client connection. w.syslogger.Debug("Creating a connection to a remote component...", "component", reg.Name) opts := call.ClientOptions{ + Balancer: balancer, Logger: w.syslogger, WriteFlattenLimit: 4 << 10, } @@ -366,7 +367,6 @@ func (w *RemoteWeavelet) makeStub(reg *codegen.Registration, resolver *routingRe component: reg.Name, conn: conn, methods: methods, - balancer: balancer, tracer: w.tracer, }, nil } diff --git a/internal/weaver/routing.go b/internal/weaver/routing.go index d98526c6d..089f432b3 100644 --- a/internal/weaver/routing.go +++ b/internal/weaver/routing.go @@ -28,22 +28,44 @@ import ( // routingBalancer balances requests according to a routing assignment. type routingBalancer struct { - balancer call.Balancer // default balancer + balancer call.Balancer // balancer to use for non-routed calls tlsConfig *tls.Config // tls config to use; may be nil. mu sync.RWMutex assignment *protos.Assignment index index + + // Map from address to connection. We currently allow just one + // connection per address. + // Guarded by mu. + conns map[string]call.ReplicaConnection } // newRoutingBalancer returns a new routingBalancer. func newRoutingBalancer(tlsConfig *tls.Config) *routingBalancer { - return &routingBalancer{balancer: call.RoundRobin(), tlsConfig: tlsConfig} + return &routingBalancer{ + balancer: call.RoundRobin(), + tlsConfig: tlsConfig, + conns: map[string]call.ReplicaConnection{}, + } } -// Update implements the call.Balancer interface. -func (rb *routingBalancer) Update(endpoints []call.Endpoint) { - rb.balancer.Update(endpoints) +// Add adds c to the set of connections we are balancing across. +func (rb *routingBalancer) Add(c call.ReplicaConnection) { + rb.balancer.Add(c) + + rb.mu.Lock() + defer rb.mu.Unlock() + rb.conns[c.Address()] = c +} + +// Remove removes c from the set of connections we are balancing across. +func (rb *routingBalancer) Remove(c call.ReplicaConnection) { + rb.balancer.Remove(c) + + rb.mu.Lock() + defer rb.mu.Unlock() + delete(rb.conns, c.Address()) } // update updates the balancer with the provided assignment @@ -60,7 +82,7 @@ func (rb *routingBalancer) update(assignment *protos.Assignment) { } // Pick implements the call.Balancer interface. -func (rb *routingBalancer) Pick(opts call.CallOptions) (call.Endpoint, error) { +func (rb *routingBalancer) Pick(opts call.CallOptions) (call.ReplicaConnection, bool) { if opts.ShardKey == 0 { // If the method we're calling is not sharded (which is guaranteed to // be true for nonsharded components), then the shard key is 0. @@ -88,17 +110,21 @@ func (rb *routingBalancer) Pick(opts call.CallOptions) (call.Endpoint, error) { return rb.balancer.Pick(opts) } - // TODO(mwhittaker): Double check that the endpoint in the slice is one of - // the endpoints in rb.endpoints. - // - // TODO(mwhittaker): Parse the endpoints when an assignment is received, - // rather than once per call. - addr := slice.replicas[rand.Intn(len(slice.replicas))] - endpoints, err := parseEndpoints([]string{addr}, rb.tlsConfig) - if err != nil { - return nil, err + // Search for an available ReplicConnection starting at a random offset. 
+ // TODO(sanjay):Precompute the set of available ReplicaConnections per slice. + offset := rand.Intn(len(slice.replicas)) + rb.mu.RLock() + defer rb.mu.RUnlock() + for i, n := 0, len(slice.replicas); i < n; i++ { + offset++ + if offset == n { + offset = 0 + } + if c, ok := rb.conns[slice.replicas[offset]]; ok { + return c, true + } } - return endpoints[0], nil + return nil, false } // routingResolver is a dummy resolver that returns whatever endpoints are diff --git a/internal/weaver/routing_test.go b/internal/weaver/routing_test.go index 042488c2b..9e602c465 100644 --- a/internal/weaver/routing_test.go +++ b/internal/weaver/routing_test.go @@ -39,6 +39,11 @@ func (ne nilEndpoint) Address() string { return ne.Addr } +// fakeConn is a fake call.ReplicaConnection used for testing. +type fakeConn string + +func (f fakeConn) Address() string { return string(f) } + // TestRoutingBalancerNoAssignment tests that a routingBalancer with no // assignment will use its default balancer instead. func TestRoutingBalancerNoAssignment(t *testing.T) { @@ -47,16 +52,12 @@ func TestRoutingBalancerNoAssignment(t *testing.T) { {ShardKey: 1}, } { t.Run(fmt.Sprint(opts.ShardKey), func(t *testing.T) { - b := call.BalancerFunc(func([]call.Endpoint, call.CallOptions) (call.Endpoint, error) { - return nilEndpoint{"a"}, nil + b := call.BalancerFunc(func([]call.ReplicaConnection, call.CallOptions) (call.ReplicaConnection, bool) { + return fakeConn("a"), false }) rb := routingBalancer{balancer: b} - got, err := rb.Pick(opts) - if err != nil { - t.Fatal(err) - } - if want := (nilEndpoint{"a"}); got != want { - t.Fatalf("rb.Pick(%v): got %v, want %v", opts, got, want) + if got, ok := rb.Pick(opts); ok { + t.Fatalf("r.Pick unexpectedly returned %s", got.Address()) } }) } @@ -65,39 +66,42 @@ func TestRoutingBalancerNoAssignment(t *testing.T) { // TestRoutingBalancer tests that a routingBalancer with an assignment will // pick endpoints using its assignment. 
func TestRoutingBalancer(t *testing.T) { - b := call.BalancerFunc(func([]call.Endpoint, call.CallOptions) (call.Endpoint, error) { - return nil, fmt.Errorf("default balancer called") + b := call.BalancerFunc(func([]call.ReplicaConnection, call.CallOptions) (call.ReplicaConnection, bool) { + t.Fatal("default balancer called") + return nil, false }) - rb := routingBalancer{balancer: b} + rb := routingBalancer{balancer: b, conns: map[string]call.ReplicaConnection{}} assignment := &protos.Assignment{ Slices: []*protos.Assignment_Slice{ { Start: 0, - Replicas: []string{"tcp://a"}, + Replicas: []string{"a"}, }, { Start: 100, - Replicas: []string{"tcp://b"}, + Replicas: []string{"b"}, }, }, } rb.update(assignment) + rb.Add(fakeConn("a")) + rb.Add(fakeConn("b")) for _, test := range []struct { shardKey uint64 - want call.NetEndpoint + want string }{ - {20, call.TCP("a")}, - {120, call.TCP("b")}, + {20, "a"}, + {120, "b"}, } { t.Run(fmt.Sprint(test.shardKey), func(t *testing.T) { - got, err := rb.Pick(call.CallOptions{ShardKey: test.shardKey}) - if err != nil { - t.Fatal(err) + got, ok := rb.Pick(call.CallOptions{ShardKey: test.shardKey}) + if !ok { + t.Fatal("did not find replica") } - if got != test.want { - t.Fatalf("rb.Pick(%d): got %v, want %v", test.shardKey, got, test.want) + if got.Address() != test.want { + t.Fatalf("rb.Pick(%d): got %s, want %s", test.shardKey, got.Address(), test.want) } }) } diff --git a/internal/weaver/stub.go b/internal/weaver/stub.go index 9bfb7a326..adc8ae5eb 100644 --- a/internal/weaver/stub.go +++ b/internal/weaver/stub.go @@ -27,7 +27,6 @@ type stub struct { component string // name of the remote component conn call.Connection // connection to talk to the remote component methods []call.MethodKey // keys for the remote component methods - balancer call.Balancer // if not nil, component load balancer tracer trace.Tracer // component tracer } @@ -42,7 +41,6 @@ func (s *stub) Tracer() trace.Tracer { func (s *stub) Run(ctx context.Context, method int, args []byte, shardKey uint64) ([]byte, error) { opts := call.CallOptions{ ShardKey: shardKey, - Balancer: s.balancer, } return s.conn.Call(ctx, s.methods[method], args, opts) } diff --git a/weavertest/internal/generate/app_test.go b/weavertest/internal/generate/app_test.go index e0ffa8647..1c911631d 100644 --- a/weavertest/internal/generate/app_test.go +++ b/weavertest/internal/generate/app_test.go @@ -25,30 +25,40 @@ import ( ) // TODO(mwhittaker): Induce an error in the encoding, decoding, and RPC call. -func TestErrors(t *testing.T) { + +func TestSuccess(t *testing.T) { ctx := context.Background() weavertest.Multi.Test(t, func(t *testing.T, client testApp) { - // Trigger an application error. Verify that an application error - // is returned. - x, err := client.Get(ctx, "foo", appError) - if err == nil || !strings.Contains(err.Error(), "key foo not found") { - t.Fatalf("expected an application error; got: %v", err) + // Do a normal get operation. Verify that the operation succeeds. + x, err := client.Get(ctx, "foo", noError) + if err != nil { + t.Fatal(err) } if x != 42 { t.Fatalf("client.Get: got %d, want 42", x) } + }) +} - // Do a normal get operation. Verify that the operation succeeds. - x, err = client.Get(ctx, "foo", noError) - if err != nil { - t.Fatal(err) +func TestAppError(t *testing.T) { + ctx := context.Background() + weavertest.Multi.Test(t, func(t *testing.T, client testApp) { + // Trigger an application error. Verify that an application error + // is returned. 
+ x, err := client.Get(ctx, "foo", appError) + if err == nil || !strings.Contains(err.Error(), "key foo not found") { + t.Fatalf("expected an application error; got: %v", err) } if x != 42 { t.Fatalf("client.Get: got %d, want 42", x) } + }) +} - // Check custom error. - _, err = client.Get(ctx, "custom", customError) +func TestCustomError(t *testing.T) { + ctx := context.Background() + weavertest.Multi.Test(t, func(t *testing.T, client testApp) { + _, err := client.Get(ctx, "custom", customError) if err == nil { t.Fatal(err) } @@ -58,9 +68,15 @@ func TestErrors(t *testing.T) { } else if c.key != "custom" { t.Errorf("customError contained wrong key %q, expecting %q", c.key, "custom") } + }) +} +func TestPanic(t *testing.T) { + t.Skip("weavertest crashes if any component panics, even in another process") + ctx := context.Background() + weavertest.Multi.Test(t, func(t *testing.T, client testApp) { // Trigger a panic. - _, err = client.Get(ctx, "foo", panicError) + _, err := client.Get(ctx, "foo", panicError) if err == nil || !errors.Is(err, weaver.RemoteCallError) { t.Fatalf("expected a weaver.RemoteCallError; got: %v", err) }