diff --git a/runners/portability/java/src/main/java/org/apache/beam/runners/portability/JobServicePipelineResult.java b/runners/portability/java/src/main/java/org/apache/beam/runners/portability/JobServicePipelineResult.java index 261017c83d6c2..9b286273a367a 100644 --- a/runners/portability/java/src/main/java/org/apache/beam/runners/portability/JobServicePipelineResult.java +++ b/runners/portability/java/src/main/java/org/apache/beam/runners/portability/JobServicePipelineResult.java @@ -17,182 +17,196 @@ */ package org.apache.beam.runners.portability; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; - +import java.util.Iterator; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.model.jobmanagement.v1.JobApi; import org.apache.beam.model.jobmanagement.v1.JobApi.CancelJobRequest; import org.apache.beam.model.jobmanagement.v1.JobApi.CancelJobResponse; import org.apache.beam.model.jobmanagement.v1.JobApi.GetJobStateRequest; +import org.apache.beam.model.jobmanagement.v1.JobApi.JobMessage; +import org.apache.beam.model.jobmanagement.v1.JobApi.JobMessagesRequest; +import org.apache.beam.model.jobmanagement.v1.JobApi.JobMessagesResponse; import org.apache.beam.model.jobmanagement.v1.JobApi.JobStateEvent; import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc.JobServiceBlockingStub; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.metrics.MetricResults; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ListenableFuture; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ListeningScheduledExecutorService; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +@SuppressWarnings({"keyfor", "nullness"}) // TODO(https://github.com/apache/beam/issues/20497) class JobServicePipelineResult implements PipelineResult, AutoCloseable { private static final long POLL_INTERVAL_MS = 3_000; private static final Logger LOG = LoggerFactory.getLogger(JobServicePipelineResult.class); - private final ListeningScheduledExecutorService executorService = - MoreExecutors.listeningDecorator(Executors.newSingleThreadScheduledExecutor()); - - private final String jobId; - private final JobServiceBlockingStub jobService; - private final AtomicReference latestState = new AtomicReference<>(State.UNKNOWN); - private final Runnable cleanup; - private final AtomicReference jobMetrics = - new AtomicReference<>(PortableMetrics.of(JobApi.MetricResults.getDefaultInstance())); - private CompletableFuture terminalStateFuture = new CompletableFuture<>(); - private CompletableFuture metricResultsCompletableFuture = - new CompletableFuture<>(); - JobServicePipelineResult(String jobId, JobServiceBlockingStub jobService, Runnable cleanup) { + private final ByteString jobId; + private final int jobServerTimeout; + private final CloseableResource jobService; + private @Nullable State terminalState; + private final @Nullable Runnable cleanup; + private org.apache.beam.model.jobmanagement.v1.JobApi.MetricResults jobMetrics; + + JobServicePipelineResult( + ByteString jobId, + int jobServerTimeout, + CloseableResource jobService, + Runnable cleanup) { this.jobId = jobId; + this.jobServerTimeout = jobServerTimeout; this.jobService = jobService; + this.terminalState = null; this.cleanup = cleanup; } @Override public State getState() { - if (latestState.get().isTerminal()) { - return latestState.get(); + if (terminalState != null) { + return terminalState; } + JobServiceBlockingStub stub = + jobService.get().withDeadlineAfter(jobServerTimeout, TimeUnit.SECONDS); JobStateEvent response = - jobService.getState(GetJobStateRequest.newBuilder().setJobId(jobId).build()); - State state = State.valueOf(response.getState().name()); - latestState.set(state); - return state; + stub.getState(GetJobStateRequest.newBuilder().setJobIdBytes(jobId).build()); + return getJavaState(response.getState()); } @Override public State cancel() { - if (latestState.get().isTerminal()) { - return latestState.get(); - } + JobServiceBlockingStub stub = jobService.get(); CancelJobResponse response = - jobService.cancel(CancelJobRequest.newBuilder().setJobId(jobId).build()); - State state = State.valueOf(response.getState().name()); - latestState.set(state); - return state; + stub.cancel(CancelJobRequest.newBuilder().setJobIdBytes(jobId).build()); + return getJavaState(response.getState()); } + @Nullable @Override public State waitUntilFinish(Duration duration) { - if (latestState.get().isTerminal()) { - return latestState.get(); - } - try { - return pollForTerminalState().get(duration.getMillis(), TimeUnit.MILLISECONDS); - } catch (InterruptedException | ExecutionException | TimeoutException e) { - throw new RuntimeException(e); + if (duration.compareTo(Duration.millis(1)) <= 0) { + // Equivalent to infinite timeout. + return waitUntilFinish(); + } else { + CompletableFuture result = CompletableFuture.supplyAsync(this::waitUntilFinish); + try { + return result.get(duration.getMillis(), TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + // Null result indicates a timeout. + return null; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } } } @Override public State waitUntilFinish() { - if (latestState.get().isTerminal()) { - return latestState.get(); + if (terminalState != null) { + return terminalState; } try { - return pollForTerminalState().get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); + waitForTerminalState(); + propagateErrors(); + return terminalState; + } finally { + close(); } } - void setTerminalStateFuture(CompletableFuture terminalStateFuture) { - this.terminalStateFuture = terminalStateFuture; - } - - CompletableFuture getTerminalStateFuture() { - return this.terminalStateFuture; - } - - void setMetricResultsCompletableFuture( - CompletableFuture metricResultsCompletableFuture) { - this.metricResultsCompletableFuture = metricResultsCompletableFuture; - } - - CompletableFuture getMetricResultsCompletableFuture() { - return this.metricResultsCompletableFuture; + @Override + public MetricResults metrics() { + return PortableMetrics.of(jobMetrics); } - CompletableFuture pollForTerminalState() { - CompletableFuture completableFuture = new CompletableFuture<>(); - ListenableFuture future = - executorService.scheduleAtFixedRate( - () -> { - State state = getState(); - LOG.info("Job: {} latest state: {}", jobId, state); - latestState.set(state); - if (state.isTerminal()) { - completableFuture.complete(state); - } - }, - 0L, - POLL_INTERVAL_MS, - TimeUnit.MILLISECONDS); - return completableFuture.whenComplete( - (state, throwable) -> { - checkState( - state.isTerminal(), - "future should have completed with a terminal state, got: %s", - state); - future.cancel(true); - LOG.info("Job: {} reached terminal state: {}", jobId, state); - if (throwable != null) { - throw new RuntimeException(throwable); - } - }); + @Override + public void close() { + try (CloseableResource jobService = this.jobService) { + JobApi.GetJobMetricsRequest metricsRequest = + JobApi.GetJobMetricsRequest.newBuilder().setJobIdBytes(jobId).build(); + jobMetrics = jobService.get().getJobMetrics(metricsRequest).getMetrics(); + if (cleanup != null) { + cleanup.run(); + } + } catch (Exception e) { + LOG.warn("Error cleaning up job service", e); + } } - CompletableFuture pollForMetrics() { - CompletableFuture completableFuture = new CompletableFuture<>(); - ListenableFuture future = - executorService.scheduleAtFixedRate( - () -> { - if (latestState.get().isTerminal()) { - completableFuture.complete(jobMetrics.get()); - return; - } - JobApi.GetJobMetricsRequest metricsRequest = - JobApi.GetJobMetricsRequest.newBuilder().setJobId(jobId).build(); - JobApi.MetricResults results = jobService.getJobMetrics(metricsRequest).getMetrics(); - jobMetrics.set(PortableMetrics.of(results)); - }, - 0L, - 1L, - TimeUnit.SECONDS); - return completableFuture.whenComplete( - ((metricResults, throwable) -> { - checkState( - latestState.get().isTerminal(), - "future should have completed with a terminal state, got: %s", - latestState.get()); - future.cancel(true); - LOG.info("Job: {} latest metrics: {}", jobId, metricResults.toString()); - })); + private void waitForTerminalState() { + JobServiceBlockingStub stub = + jobService.get().withDeadlineAfter(jobServerTimeout, TimeUnit.SECONDS); + GetJobStateRequest request = GetJobStateRequest.newBuilder().setJobIdBytes(jobId).build(); + JobStateEvent response = stub.getState(request); + State lastState = getJavaState(response.getState()); + while (!lastState.isTerminal()) { + try { + Thread.sleep(POLL_INTERVAL_MS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + response = stub.withDeadlineAfter(jobServerTimeout, TimeUnit.SECONDS).getState(request); + lastState = getJavaState(response.getState()); + } + terminalState = lastState; } - @Override - public MetricResults metrics() { - return jobMetrics.get(); + private void propagateErrors() { + if (terminalState != State.DONE) { + JobMessagesRequest messageStreamRequest = + JobMessagesRequest.newBuilder().setJobIdBytes(jobId).build(); + Iterator messageStreamIterator = + jobService + .get() + .withDeadlineAfter(jobServerTimeout, TimeUnit.SECONDS) + .getMessageStream(messageStreamRequest); + while (messageStreamIterator.hasNext()) { + JobMessage messageResponse = messageStreamIterator.next().getMessageResponse(); + if (messageResponse.getImportance() == JobMessage.MessageImportance.JOB_MESSAGE_ERROR) { + throw new RuntimeException( + "The Runner experienced the following error during execution:\n" + + messageResponse.getMessageText()); + } + } + } } - @Override - public void close() { - cleanup.run(); + private static State getJavaState(JobApi.JobState.Enum protoState) { + switch (protoState) { + case UNSPECIFIED: + return State.UNKNOWN; + case STOPPED: + return State.STOPPED; + case RUNNING: + return State.RUNNING; + case DONE: + return State.DONE; + case FAILED: + return State.FAILED; + case CANCELLED: + return State.CANCELLED; + case UPDATED: + return State.UPDATED; + case DRAINING: + // TODO: Determine the correct mappings for the states below. + return State.UNKNOWN; + case DRAINED: + return State.UNKNOWN; + case STARTING: + return State.RUNNING; + case CANCELLING: + return State.CANCELLED; + default: + LOG.warn("Unrecognized state from server: {}", protoState); + return State.UNKNOWN; + } } } diff --git a/runners/portability/java/src/main/java/org/apache/beam/runners/portability/PortableRunner.java b/runners/portability/java/src/main/java/org/apache/beam/runners/portability/PortableRunner.java index c967e8e5daac4..000efb7430af3 100644 --- a/runners/portability/java/src/main/java/org/apache/beam/runners/portability/PortableRunner.java +++ b/runners/portability/java/src/main/java/org/apache/beam/runners/portability/PortableRunner.java @@ -35,6 +35,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService; import org.apache.beam.runners.fnexecution.artifact.ArtifactStagingService; +import org.apache.beam.runners.portability.CloseableResource.CloseException; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.PipelineRunner; @@ -49,6 +50,7 @@ import org.apache.beam.sdk.util.construction.PipelineOptionsTranslation; import org.apache.beam.sdk.util.construction.PipelineTranslation; import org.apache.beam.sdk.util.construction.SdkComponents; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; @@ -99,7 +101,7 @@ private PortableRunner( @Override public PipelineResult run(Pipeline pipeline) { - Runnable cleanup = () -> {}; + Runnable cleanup; if (Environments.ENVIRONMENT_LOOPBACK.equals( options.as(PortablePipelineOptions.class).getDefaultEnvironmentType())) { GrpcFnServer workerService; @@ -121,6 +123,8 @@ public PipelineResult run(Pipeline pipeline) { throw new RuntimeException(exn); } }; + } else { + cleanup = null; } ImmutableList.Builder filesToStageBuilder = ImmutableList.builder(); @@ -170,47 +174,55 @@ public PipelineResult run(Pipeline pipeline) { ManagedChannel jobServiceChannel = channelFactory.forDescriptor(ApiServiceDescriptor.newBuilder().setUrl(endpoint).build()); - int jobServerTimeout = options.as(PortablePipelineOptions.class).getJobServerTimeout(); - JobServiceBlockingStub jobService = - JobServiceGrpc.newBlockingStub(jobServiceChannel) - .withDeadlineAfter(jobServerTimeout, TimeUnit.SECONDS); - - PrepareJobResponse prepareJobResponse = jobService.prepare(prepareJobRequest); - LOG.info("PrepareJobResponse: {}", prepareJobResponse); - - ApiServiceDescriptor artifactStagingEndpoint = prepareJobResponse.getArtifactStagingEndpoint(); - String stagingSessionToken = prepareJobResponse.getStagingSessionToken(); - - try (CloseableResource artifactChannel = - CloseableResource.of( - channelFactory.forDescriptor(artifactStagingEndpoint), ManagedChannel::shutdown)) { - - ArtifactStagingService.offer( - new ArtifactRetrievalService(), - ArtifactStagingServiceGrpc.newStub(artifactChannel.get()), - stagingSessionToken); - } catch (CloseableResource.CloseException e) { - LOG.warn("Error closing artifact staging channel", e); - // CloseExceptions should only be thrown while closing the channel. - } catch (Exception e) { - throw new RuntimeException("Error staging files.", e); - } + JobServiceBlockingStub jobService = JobServiceGrpc.newBlockingStub(jobServiceChannel); + try (CloseableResource wrappedJobService = + CloseableResource.of(jobService, unused -> jobServiceChannel.shutdown())) { + + final int jobServerTimeout = options.as(PortablePipelineOptions.class).getJobServerTimeout(); + PrepareJobResponse prepareJobResponse = + jobService + .withDeadlineAfter(jobServerTimeout, TimeUnit.SECONDS) + .withWaitForReady() + .prepare(prepareJobRequest); + LOG.info("PrepareJobResponse: {}", prepareJobResponse); + + ApiServiceDescriptor artifactStagingEndpoint = + prepareJobResponse.getArtifactStagingEndpoint(); + String stagingSessionToken = prepareJobResponse.getStagingSessionToken(); + + try (CloseableResource artifactChannel = + CloseableResource.of( + channelFactory.forDescriptor(artifactStagingEndpoint), ManagedChannel::shutdown)) { + + ArtifactStagingService.offer( + new ArtifactRetrievalService(), + ArtifactStagingServiceGrpc.newStub(artifactChannel.get()), + stagingSessionToken); + } catch (CloseableResource.CloseException e) { + LOG.warn("Error closing artifact staging channel", e); + // CloseExceptions should only be thrown while closing the channel. + } catch (Exception e) { + throw new RuntimeException("Error staging files.", e); + } - RunJobRequest runJobRequest = - RunJobRequest.newBuilder().setPreparationId(prepareJobResponse.getPreparationId()).build(); + RunJobRequest runJobRequest = + RunJobRequest.newBuilder() + .setPreparationId(prepareJobResponse.getPreparationId()) + .build(); - // Run the job and wait for a result, we don't set a timeout here because - // it may take a long time for a job to complete and streaming - // jobs never return a response. - RunJobResponse runJobResponse = jobService.run(runJobRequest); + // Run the job and wait for a result, we don't set a timeout here because + // it may take a long time for a job to complete and streaming + // jobs never return a response. + RunJobResponse runJobResponse = jobService.run(runJobRequest); - LOG.info("RunJobResponse: {}", runJobResponse); - String jobId = runJobResponse.getJobId(); + LOG.info("RunJobResponse: {}", runJobResponse); + ByteString jobId = runJobResponse.getJobIdBytes(); - JobServicePipelineResult result = new JobServicePipelineResult(jobId, jobService, cleanup); - result.setTerminalStateFuture(result.pollForTerminalState()); - result.setMetricResultsCompletableFuture(result.pollForMetrics()); - return result; + return new JobServicePipelineResult( + jobId, jobServerTimeout, wrappedJobService.transfer(), cleanup); + } catch (CloseException e) { + throw new RuntimeException(e); + } } @Override diff --git a/runners/portability/java/src/test/java/org/apache/beam/runners/portability/PortableRunnerTest.java b/runners/portability/java/src/test/java/org/apache/beam/runners/portability/PortableRunnerTest.java index f89a55b00e8dd..788d4a43319d6 100644 --- a/runners/portability/java/src/test/java/org/apache/beam/runners/portability/PortableRunnerTest.java +++ b/runners/portability/java/src/test/java/org/apache/beam/runners/portability/PortableRunnerTest.java @@ -57,7 +57,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; import org.joda.time.Duration; import org.joda.time.Instant; -import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -106,7 +105,6 @@ public void stagesAndRunsJob() throws Exception { assertThat(state, is(State.DONE)); } - @Ignore @Test public void extractsMetrics() throws Exception { JobApi.MetricResults metricResults = generateMetricResults();