diff --git a/components/audioinput/client.go b/components/audioinput/client.go index 6be08a4308f..6ef4426b851 100644 --- a/components/audioinput/client.go +++ b/components/audioinput/client.go @@ -23,15 +23,13 @@ import ( type client struct { resource.Named resource.TriviallyReconfigurable - resource.TriviallyCloseable conn rpc.ClientConn client pb.AudioInputServiceClient logger logging.Logger mu sync.Mutex name string activeBackgroundWorkers sync.WaitGroup - cancelCtx context.Context - cancel func() + healthyClientCh chan struct{} } // NewClientFromConn constructs a new Client from connection passed in. @@ -42,19 +40,13 @@ func NewClientFromConn( name resource.Name, logger logging.Logger, ) (AudioInput, error) { - // TODO(RSDK-6340): This client might still try to create audio streams after this - // context is canceled. These subsequent audio streams will not work. To fix this, - // use a channel instead of a context like we do in `component/audioinput/client.go` - cancelCtx, cancel := context.WithCancel(context.Background()) c := pb.NewAudioInputServiceClient(conn) return &client{ - Named: name.PrependRemote(remoteName).AsNamed(), - name: name.ShortName(), - conn: conn, - client: c, - logger: logger, - cancelCtx: cancelCtx, - cancel: cancel, + Named: name.PrependRemote(remoteName).AsNamed(), + name: name.ShortName(), + conn: conn, + client: c, + logger: logger, }, nil } @@ -75,7 +67,36 @@ func (c *client) Stream( ctx context.Context, errHandlers ...gostream.ErrorHandler, ) (gostream.AudioStream, error) { - streamCtx, stream, chunkCh := gostream.NewMediaStreamForChannel[wave.Audio](c.cancelCtx) + // RSDK-6340: The resource manager closes remote resources when the underlying + // connection goes bad. However, when the connection is re-established, the client + // objects these resources represent are not re-initialized/marked "healthy". + // `healthyClientCh` helps track these transitions between healthy and unhealthy + // states. + // + // When a new `client.Stream()` is created we will either use the existing + // `healthyClientCh` or create a new one. + // + // The goroutine a `Stream()` method spins off will listen to its version of the + // `healthyClientCh` to be notified when the connection has died so it can gracefully + // terminate. + // + // When a connection becomes unhealthy, the resource manager will call `Close` on the + // audioinput client object. Closing the client will: + // 1. close its `client.healthyClientCh` channel + // 2. wait for existing "stream" goroutines to drain + // 3. nil out the `client.healthyClientCh` member variable + // + // New streams concurrent with closing cannot start until this drain completes. There + // will never be stream goroutines from the old "generation" running concurrently + // with those from the new "generation". + c.mu.Lock() + if c.healthyClientCh == nil { + c.healthyClientCh = make(chan struct{}) + } + healthyClientCh := c.healthyClientCh + c.mu.Unlock() + + streamCtx, stream, chunkCh := gostream.NewMediaStreamForChannel[wave.Audio](context.Background()) chunksClient, err := c.client.Chunks(ctx, &pb.ChunksRequest{ Name: c.name, @@ -152,6 +173,11 @@ func (c *client) Stream( select { case <-streamCtx.Done(): return + case <-healthyClientCh: + if err := stream.Close(context.Background()); err != nil { + c.logger.Warn("error closing stream", err) + } + return case chunkCh <- gostream.MediaReleasePairWithError[wave.Audio]{ Media: chunk, Release: func() {}, @@ -186,10 +212,19 @@ func (c *client) DoCommand(ctx context.Context, cmd map[string]interface{}) (map return protoutils.DoFromResourceClient(ctx, c.client, c.name, cmd) } +// TODO(RSDK-6433): This method can be called more than once during a client's lifecycle. +// For example, consider a case where a remote audioinput goes offline and then back +// online. We will call `Close` on the audioinput client when we detect the disconnection +// to remove active streams but then reuse the client when the connection is +// re-established. func (c *client) Close(ctx context.Context) error { c.mu.Lock() - c.cancel() - c.mu.Unlock() + defer c.mu.Unlock() + + if c.healthyClientCh != nil { + close(c.healthyClientCh) + } c.activeBackgroundWorkers.Wait() + c.healthyClientCh = nil return nil } diff --git a/components/audioinput/client_test.go b/components/audioinput/client_test.go index d1ece05bb25..6fe33dd9d9a 100644 --- a/components/audioinput/client_test.go +++ b/components/audioinput/client_test.go @@ -132,3 +132,92 @@ func TestClient(t *testing.T) { test.That(t, conn.Close(), test.ShouldBeNil) }) } + +func TestClientStreamAfterClose(t *testing.T) { + // Set up gRPC server + logger := logging.NewTestLogger(t) + listener, err := net.Listen("tcp", "localhost:0") + test.That(t, err, test.ShouldBeNil) + rpcServer, err := rpc.NewServer(logger.AsZap(), rpc.WithUnauthenticated()) + test.That(t, err, test.ShouldBeNil) + + // Set up audioinput that can stream audio + + audioData := &wave.Float32Interleaved{ + Data: []float32{ + 0.1, -0.5, 0.2, -0.6, 0.3, -0.7, 0.4, -0.8, 0.5, -0.9, 0.6, -1.0, 0.7, -1.1, 0.8, -1.2, + }, + Size: wave.ChunkInfo{8, 2, 48000}, + } + + injectAudioInput := &inject.AudioInput{} + + // good audio input + injectAudioInput.StreamFunc = func(ctx context.Context, errHandlers ...gostream.ErrorHandler) (gostream.AudioStream, error) { + return gostream.NewEmbeddedAudioStreamFromReader(gostream.AudioReaderFunc(func(ctx context.Context) (wave.Audio, func(), error) { + return audioData, func() {}, nil + })), nil + } + + expectedProps := prop.Audio{ + ChannelCount: 1, + SampleRate: 2, + IsBigEndian: true, + IsInterleaved: true, + Latency: 5, + } + injectAudioInput.MediaPropertiesFunc = func(ctx context.Context) (prop.Audio, error) { + return expectedProps, nil + } + + // Register AudioInputService API in our gRPC server. + resources := map[resource.Name]audioinput.AudioInput{ + audioinput.Named(testAudioInputName): injectAudioInput, + } + audioinputSvc, err := resource.NewAPIResourceCollection(audioinput.API, resources) + test.That(t, err, test.ShouldBeNil) + resourceAPI, ok, err := resource.LookupAPIRegistration[audioinput.AudioInput](audioinput.API) + test.That(t, err, test.ShouldBeNil) + test.That(t, ok, test.ShouldBeTrue) + test.That(t, resourceAPI.RegisterRPCService(context.Background(), rpcServer, audioinputSvc), test.ShouldBeNil) + + // Start serving requests. + go rpcServer.Serve(listener) + defer rpcServer.Stop() + + // Make client connection + conn, err := viamgrpc.Dial(context.Background(), listener.Addr().String(), logger) + test.That(t, err, test.ShouldBeNil) + client, err := audioinput.NewClientFromConn(context.Background(), conn, "", audioinput.Named(testAudioInputName), logger) + test.That(t, err, test.ShouldBeNil) + + // Get a stream + stream, err := client.Stream(context.Background()) + test.That(t, stream, test.ShouldNotBeNil) + test.That(t, err, test.ShouldBeNil) + + // Read from stream + media, _, err := stream.Next(context.Background()) + test.That(t, media, test.ShouldNotBeNil) + test.That(t, err, test.ShouldBeNil) + + // Close client and read from stream + test.That(t, client.Close(context.Background()), test.ShouldBeNil) + media, _, err = stream.Next(context.Background()) + test.That(t, media, test.ShouldBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "context canceled") + + // Get a new stream + stream, err = client.Stream(context.Background()) + test.That(t, stream, test.ShouldNotBeNil) + test.That(t, err, test.ShouldBeNil) + + // Read from the new stream + media, _, err = stream.Next(context.Background()) + test.That(t, media, test.ShouldNotBeNil) + test.That(t, err, test.ShouldBeNil) + + // Close client and connection + test.That(t, client.Close(context.Background()), test.ShouldBeNil) + test.That(t, conn.Close(), test.ShouldBeNil) +}