diff --git a/sdks/go/pkg/beam/core/runtime/harness/harness.go b/sdks/go/pkg/beam/core/runtime/harness/harness.go
index 02d82b3a513f..df6950e9e645 100644
--- a/sdks/go/pkg/beam/core/runtime/harness/harness.go
+++ b/sdks/go/pkg/beam/core/runtime/harness/harness.go
@@ -53,7 +53,7 @@ func Main(ctx context.Context, loggingEndpoint, controlEndpoint string, options
 		case StatusAddress:
 			statusEndpoint = string(option)
 		default:
-			return errors.Errorf("unkown type %T, value %v in error call", option, option)
+			return errors.Errorf("unknown type %T, value %v in error call", option, option)
 		}
 	}
 
@@ -115,12 +115,12 @@ func Main(ctx context.Context, loggingEndpoint, controlEndpoint string, options
 	if statusEndpoint != "" {
 		statusHandler, err := newWorkerStatusHandler(ctx, statusEndpoint)
 		if err != nil {
-			log.Error(ctx, err)
+			log.Errorf(ctx, "error establishing connection to worker status API: %v", err)
+		} else {
+			statusHandler.wg.Add(1)
+			statusHandler.start(ctx)
+			defer statusHandler.stop(ctx)
 		}
-		var swg sync.WaitGroup
-		swg.Add(1)
-		statusHandler.handleRequest(ctx, &swg)
-		defer statusHandler.close(ctx, &swg)
 	}
 
 	sideCache := statecache.SideInputCache{}
diff --git a/sdks/go/pkg/beam/core/runtime/harness/worker_status.go b/sdks/go/pkg/beam/core/runtime/harness/worker_status.go
index da85a04f351f..bb7bfa4f246d 100644
--- a/sdks/go/pkg/beam/core/runtime/harness/worker_status.go
+++ b/sdks/go/pkg/beam/core/runtime/harness/worker_status.go
@@ -30,41 +30,42 @@ import (
 
 // workerStatusHandler stores the communication information of WorkerStatus API.
 type workerStatusHandler struct {
-	conn     *grpc.ClientConn
-	shutdown int32
+	conn           *grpc.ClientConn
+	shouldShutdown int32
+	wg             sync.WaitGroup
 }
 
 func newWorkerStatusHandler(ctx context.Context, endpoint string) (*workerStatusHandler, error) {
 	sconn, err := dial(ctx, endpoint, 60*time.Second)
 	if err != nil {
-		return &workerStatusHandler{}, errors.Wrapf(err, "failed to connect: %v\n", endpoint)
+		return nil, errors.Wrapf(err, "failed to connect: %v\n", endpoint)
 	}
-	return &workerStatusHandler{conn: sconn, shutdown: 0}, nil
+	return &workerStatusHandler{conn: sconn, shouldShutdown: 0}, nil
 }
 
 func (w *workerStatusHandler) isAlive() bool {
-	return atomic.LoadInt32(&w.shutdown) == 0
+	return atomic.LoadInt32(&w.shouldShutdown) == 0
 }
 
-func (w *workerStatusHandler) stop() {
-	atomic.StoreInt32(&w.shutdown, 1)
+func (w *workerStatusHandler) shutdown() {
+	atomic.StoreInt32(&w.shouldShutdown, 1)
 }
 
-// handleRequest manages the WorkerStatus API.
-func (w *workerStatusHandler) handleRequest(ctx context.Context, wg *sync.WaitGroup) {
+// start starts the reader to accept WorkerStatusRequest and send WorkerStatusResponse with WorkerStatus API.
+func (w *workerStatusHandler) start(ctx context.Context) {
 	statusClient := fnpb.NewBeamFnWorkerStatusClient(w.conn)
 	stub, err := statusClient.WorkerStatus(ctx)
 	if err != nil {
 		log.Errorf(ctx, "status client not established: %v", err)
 		return
 	}
-	go w.reader(ctx, stub, wg)
+	go w.reader(ctx, stub)
 }
 
 // reader reads the WorkerStatusRequest from the stream and sends a processed WorkerStatusResponse to
 // a response channel.
-func (w *workerStatusHandler) reader(ctx context.Context, stub fnpb.BeamFnWorkerStatus_WorkerStatusClient, wg *sync.WaitGroup) {
-	defer wg.Done()
+func (w *workerStatusHandler) reader(ctx context.Context, stub fnpb.BeamFnWorkerStatus_WorkerStatusClient) {
+	defer w.wg.Done()
 	buf := make([]byte, 1<<16)
 	for w.isAlive() {
 		req, err := stub.Recv()
@@ -81,10 +82,10 @@ func (w *workerStatusHandler) reader(ctx context.Context, stub fnpb.BeamFnWorker
 	}
 }
 
-// close stops the reader first, closes the response channel thereby stopping writer and finally closes the gRPC connection.
-func (w *workerStatusHandler) close(ctx context.Context, wg *sync.WaitGroup) {
-	w.stop()
-	wg.Wait()
+// stop stops the reader and closes worker status endpoint connection with the runner.
+func (w *workerStatusHandler) stop(ctx context.Context) {
+	w.shutdown()
+	w.wg.Wait()
 	if err := w.conn.Close(); err != nil {
 		log.Errorf(ctx, "error closing status endpoint connection: %v", err)
 	}
diff --git a/sdks/go/pkg/beam/core/runtime/harness/worker_status_test.go b/sdks/go/pkg/beam/core/runtime/harness/worker_status_test.go
index 4fa04c8739c6..57060b20f890 100644
--- a/sdks/go/pkg/beam/core/runtime/harness/worker_status_test.go
+++ b/sdks/go/pkg/beam/core/runtime/harness/worker_status_test.go
@@ -19,7 +19,6 @@ import (
 	"fmt"
 	"log"
 	"net"
-	"sync"
 	"testing"
 
 	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
@@ -47,10 +46,8 @@ const buffsize = 1024 * 1024
 
 var lis *bufconn.Listener
 
-func setup(srv *BeamFnWorkerStatusServicer) {
-
+func setup(t *testing.T, srv *BeamFnWorkerStatusServicer) {
 	server := grpc.NewServer()
-
 	lis = bufconn.Listen(buffsize)
 	fnpb.RegisterBeamFnWorkerStatusServer(server, srv)
 	go func() {
@@ -58,6 +55,9 @@ func setup(srv *BeamFnWorkerStatusServicer) {
 			log.Fatalf("failed to serve: %v", err)
 		}
 	}()
+	t.Cleanup(func() {
+		server.Stop()
+	})
 }
 
 func dialer(context.Context, string) (net.Conn, error) {
@@ -67,17 +67,16 @@ func TestSendStatusResponse(t *testing.T) {
 	ctx := context.Background()
 	srv := &BeamFnWorkerStatusServicer{response: make(chan string)}
-	setup(srv)
+	setup(t, srv)
 	conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(dialer), grpc.WithInsecure())
 	if err != nil {
 		t.Fatalf("unable to start test server: %v", err)
 	}
 	statusHandler := workerStatusHandler{conn: conn}
-	var wg sync.WaitGroup
-	wg.Add(1)
-	statusHandler.handleRequest(ctx, &wg)
+	statusHandler.wg.Add(1)
+	statusHandler.start(ctx)
 	t.Cleanup(func() {
-		statusHandler.close(ctx, &wg)
+		statusHandler.stop(ctx)
 	})
 	response := []string{}
 	response = append(response, <-srv.response)
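
Note on the reworked lifecycle: the sync.WaitGroup that was previously threaded through handleRequest/close as a parameter now lives on workerStatusHandler itself, so a call site pairs wg.Add(1) with start(ctx) and defers stop(ctx), which flips shouldShutdown, waits for the reader goroutine, and closes the gRPC connection. Below is a minimal call-site sketch, assuming it sits in the harness package next to the code above; runStatusHandler is a hypothetical wrapper, not part of this patch (the real call site is inlined in harness.Main).

// runStatusHandler is a hypothetical sketch of the new start/stop contract;
// it mirrors the harness.go hunk above and is not part of this patch.
func runStatusHandler(ctx context.Context, statusEndpoint string) {
	statusHandler, err := newWorkerStatusHandler(ctx, statusEndpoint)
	if err != nil {
		// The worker status API is best-effort: log the failure and run without it.
		log.Errorf(ctx, "error establishing connection to worker status API: %v", err)
		return
	}
	// Add(1) pairs with "defer w.wg.Done()" in the reader goroutine.
	statusHandler.wg.Add(1)
	statusHandler.start(ctx)
	// stop sets shouldShutdown, waits on the WaitGroup, then closes the conn.
	defer statusHandler.stop(ctx)
	// ... run the rest of the harness ...
}

One caveat with this contract: if start(ctx) fails to open the status stream, it logs and returns before spawning reader, so the wg.Add(1) is never balanced by Done() and a later stop(ctx) blocks in wg.Wait(); callers should keep that error path in mind.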