Skip to content

Commit

Permalink
[ML] retry bulk indexing of state docs (elastic#50149)
Browse files Browse the repository at this point in the history
This replaces the direct use of `Client` with `ResultsPersisterService`, so state document persistence is now retried. Failures to persist state still do not throw, but they are now audited and logged.
  • Loading branch information
benwtrent authored and SivagurunathanV committed Jan 21, 2020
1 parent fdf820d commit 1c2b083
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -559,11 +559,17 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
environment,
settings,
nativeController,
client,
clusterService);
clusterService,
resultsPersisterService,
anomalyDetectionAuditor);
normalizerProcessFactory = new NativeNormalizerProcessFactory(environment, nativeController, clusterService);
analyticsProcessFactory = new NativeAnalyticsProcessFactory(environment, client, nativeController, clusterService,
xContentRegistry);
analyticsProcessFactory = new NativeAnalyticsProcessFactory(
environment,
nativeController,
clusterService,
xContentRegistry,
resultsPersisterService,
dataFrameAnalyticsAuditor);
memoryEstimationProcessFactory =
new NativeMemoryUsageEstimationProcessFactory(environment, nativeController, clusterService);
mlController = nativeController;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.bytes.BytesReference;
Expand All @@ -20,10 +19,12 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.elasticsearch.xpack.ml.utils.NamedPipeHelper;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.nio.file.Path;
Expand All @@ -40,18 +41,24 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<An

private static final NamedPipeHelper NAMED_PIPE_HELPER = new NamedPipeHelper();

private final Client client;
private final Environment env;
private final NativeController nativeController;
private final NamedXContentRegistry namedXContentRegistry;
private final ResultsPersisterService resultsPersisterService;
private final DataFrameAnalyticsAuditor auditor;
private volatile Duration processConnectTimeout;

public NativeAnalyticsProcessFactory(Environment env, Client client, NativeController nativeController, ClusterService clusterService,
NamedXContentRegistry namedXContentRegistry) {
public NativeAnalyticsProcessFactory(Environment env,
NativeController nativeController,
ClusterService clusterService,
NamedXContentRegistry namedXContentRegistry,
ResultsPersisterService resultsPersisterService,
DataFrameAnalyticsAuditor auditor) {
this.env = Objects.requireNonNull(env);
this.client = Objects.requireNonNull(client);
this.nativeController = Objects.requireNonNull(nativeController);
this.namedXContentRegistry = Objects.requireNonNull(namedXContentRegistry);
this.auditor = auditor;
this.resultsPersisterService = resultsPersisterService;
setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(env.settings()));
clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.PROCESS_CONNECT_TIMEOUT,
this::setProcessConnectTimeout);
Expand Down Expand Up @@ -96,7 +103,7 @@ public NativeAnalyticsProcess createAnalyticsProcess(DataFrameAnalyticsConfig co
private void startProcess(DataFrameAnalyticsConfig config, ExecutorService executorService, ProcessPipes processPipes,
NativeAnalyticsProcess process) {
if (config.getAnalysis().persistsState()) {
IndexingStateProcessor stateProcessor = new IndexingStateProcessor(client, config.getId());
IndexingStateProcessor stateProcessor = new IndexingStateProcessor(config.getId(), resultsPersisterService, auditor);
process.start(executorService, stateProcessor, processPipes.getPersistStream().get());
} else {
process.start(executorService);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
Expand All @@ -20,11 +19,13 @@
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
import org.elasticsearch.xpack.ml.job.results.AutodetectResult;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
import org.elasticsearch.xpack.ml.utils.NamedPipeHelper;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.nio.file.Path;
Expand All @@ -40,20 +41,26 @@ public class NativeAutodetectProcessFactory implements AutodetectProcessFactory
private static final Logger LOGGER = LogManager.getLogger(NativeAutodetectProcessFactory.class);
private static final NamedPipeHelper NAMED_PIPE_HELPER = new NamedPipeHelper();

private final Client client;
private final Environment env;
private final Settings settings;
private final NativeController nativeController;
private final ClusterService clusterService;
private final ResultsPersisterService resultsPersisterService;
private final AnomalyDetectionAuditor auditor;
private volatile Duration processConnectTimeout;

public NativeAutodetectProcessFactory(Environment env, Settings settings, NativeController nativeController, Client client,
ClusterService clusterService) {
public NativeAutodetectProcessFactory(Environment env,
Settings settings,
NativeController nativeController,
ClusterService clusterService,
ResultsPersisterService resultsPersisterService,
AnomalyDetectionAuditor auditor) {
this.env = Objects.requireNonNull(env);
this.settings = Objects.requireNonNull(settings);
this.nativeController = Objects.requireNonNull(nativeController);
this.client = client;
this.clusterService = clusterService;
this.resultsPersisterService = resultsPersisterService;
this.auditor = auditor;
setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(settings));
clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.PROCESS_CONNECT_TIMEOUT,
this::setProcessConnectTimeout);
Expand All @@ -78,7 +85,7 @@ public AutodetectProcess createAutodetectProcess(Job job,
// The extra 1 is the control field
int numberOfFields = job.allInputFields().size() + (includeTokensField ? 1 : 0) + 1;

IndexingStateProcessor stateProcessor = new IndexingStateProcessor(client, job.getId());
IndexingStateProcessor stateProcessor = new IndexingStateProcessor(job.getId(), resultsPersisterService, auditor);
ProcessResultsParser<AutodetectResult> resultsParser = new ProcessResultsParser<>(AutodetectResult.PARSER,
NamedXContentRegistry.EMPTY);
NativeAutodetectProcess autodetect = new NativeAutodetectProcess(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,22 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;

import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;

/**
* Reads state documents of a stream, splits them and persists to an index via a bulk request
Expand All @@ -32,12 +33,16 @@ public class IndexingStateProcessor implements StateProcessor {

private static final int READ_BUF_SIZE = 8192;

private final Client client;
private final String jobId;
private final AbstractAuditor<? extends AbstractAuditMessage> auditor;
private final ResultsPersisterService resultsPersisterService;

public IndexingStateProcessor(Client client, String jobId) {
this.client = client;
public IndexingStateProcessor(String jobId,
ResultsPersisterService resultsPersisterService,
AbstractAuditor<? extends AbstractAuditMessage> auditor) {
this.jobId = jobId;
this.resultsPersisterService = resultsPersisterService;
this.auditor = auditor;
}

@Override
Expand Down Expand Up @@ -98,8 +103,15 @@ void persist(BytesReference bytes) throws IOException {
bulkRequest.add(bytes, AnomalyDetectorsIndex.jobStateIndexWriteAlias(), XContentType.JSON);
if (bulkRequest.numberOfActions() > 0) {
LOGGER.trace("[{}] Persisting job state document", jobId);
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
client.bulk(bulkRequest).actionGet();
try {
resultsPersisterService.bulkIndexWithRetry(bulkRequest,
jobId,
() -> true,
(msg) -> auditor.warning(jobId, "Bulk indexing of state failed " + msg));
} catch (Exception ex) {
String msg = "failed indexing updated state docs";
LOGGER.error(() -> new ParameterizedMessage("[{}] {}", jobId, msg), ex);
auditor.error(jobId, msg + " error: " + ex.getMessage());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.ml.job.process.autodetect;

import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
Expand All @@ -16,8 +15,10 @@
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.time.Duration;
Expand All @@ -41,7 +42,8 @@ public void testSetProcessConnectTimeout() throws IOException {
.build();
Environment env = TestEnvironment.newEnvironment(settings);
NativeController nativeController = mock(NativeController.class);
Client client = mock(Client.class);
ResultsPersisterService resultsPersisterService = mock(ResultsPersisterService.class);
AnomalyDetectionAuditor anomalyDetectionAuditor = mock(AnomalyDetectionAuditor.class);
ClusterSettings clusterSettings = new ClusterSettings(settings,
Set.of(MachineLearning.PROCESS_CONNECT_TIMEOUT, AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC));
ClusterService clusterService = mock(ClusterService.class);
Expand All @@ -51,8 +53,13 @@ public void testSetProcessConnectTimeout() throws IOException {
AutodetectParams autodetectParams = mock(AutodetectParams.class);
ProcessPipes processPipes = mock(ProcessPipes.class);

NativeAutodetectProcessFactory nativeAutodetectProcessFactory =
new NativeAutodetectProcessFactory(env, settings, nativeController, client, clusterService);
NativeAutodetectProcessFactory nativeAutodetectProcessFactory = new NativeAutodetectProcessFactory(
env,
settings,
nativeController,
clusterService,
resultsPersisterService,
anomalyDetectionAuditor);
nativeAutodetectProcessFactory.setProcessConnectTimeout(TimeValue.timeValueSeconds(timeoutSeconds));
nativeAutodetectProcessFactory.createNativeProcess(job, autodetectParams, processPipes, Collections.emptyList());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
package org.elasticsearch.xpack.ml.process;

import com.carrotsearch.randomizedtesting.annotations.Timeout;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.mock.orig.Mockito;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand Down Expand Up @@ -54,24 +54,22 @@ public class IndexingStateProcessorTests extends ESTestCase {
private static final int NUM_LARGE_DOCS = 2;
private static final int LARGE_DOC_SIZE = 1000000;

private Client client;
private IndexingStateProcessor stateProcessor;
private ResultsPersisterService resultsPersisterService;

@Before
public void initialize() throws IOException {
client = mock(Client.class);
@SuppressWarnings("unchecked")
ActionFuture<BulkResponse> bulkResponseFuture = mock(ActionFuture.class);
stateProcessor = spy(new IndexingStateProcessor(client, JOB_ID));
when(client.bulk(any(BulkRequest.class))).thenReturn(bulkResponseFuture);
public void initialize() {
resultsPersisterService = mock(ResultsPersisterService.class);
AnomalyDetectionAuditor auditor = mock(AnomalyDetectionAuditor.class);
stateProcessor = spy(new IndexingStateProcessor(JOB_ID, resultsPersisterService, auditor));
when(resultsPersisterService.bulkIndexWithRetry(any(BulkRequest.class), any(), any(), any())).thenReturn(mock(BulkResponse.class));
ThreadPool threadPool = mock(ThreadPool.class);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
}

@After
public void verifyNoMoreClientInteractions() {
Mockito.verifyNoMoreInteractions(client);
Mockito.verifyNoMoreInteractions(resultsPersisterService);
}

public void testStateRead() throws IOException {
Expand All @@ -85,8 +83,7 @@ public void testStateRead() throws IOException {
assertEquals(threeStates[0], capturedBytes.get(0).utf8ToString());
assertEquals(threeStates[1], capturedBytes.get(1).utf8ToString());
assertEquals(threeStates[2], capturedBytes.get(2).utf8ToString());
verify(client, times(3)).bulk(any(BulkRequest.class));
verify(client, times(3)).threadPool();
verify(resultsPersisterService, times(3)).bulkIndexWithRetry(any(BulkRequest.class), any(), any(), any());
}

public void testStateReadGivenConsecutiveZeroBytes() throws IOException {
Expand All @@ -96,7 +93,7 @@ public void testStateReadGivenConsecutiveZeroBytes() throws IOException {
stateProcessor.process(stream);

verify(stateProcessor, never()).persist(any());
Mockito.verifyNoMoreInteractions(client);
Mockito.verifyNoMoreInteractions(resultsPersisterService);
}

public void testStateReadGivenConsecutiveSpacesFollowedByZeroByte() throws IOException {
Expand All @@ -106,7 +103,7 @@ public void testStateReadGivenConsecutiveSpacesFollowedByZeroByte() throws IOExc
stateProcessor.process(stream);

verify(stateProcessor, times(1)).persist(any());
Mockito.verifyNoMoreInteractions(client);
Mockito.verifyNoMoreInteractions(resultsPersisterService);
}

/**
Expand All @@ -128,7 +125,6 @@ public void testLargeStateRead() throws Exception {
ByteArrayInputStream stream = new ByteArrayInputStream(builder.toString().getBytes(StandardCharsets.UTF_8));
stateProcessor.process(stream);
verify(stateProcessor, times(NUM_LARGE_DOCS)).persist(any());
verify(client, times(NUM_LARGE_DOCS)).bulk(any(BulkRequest.class));
verify(client, times(NUM_LARGE_DOCS)).threadPool();
verify(resultsPersisterService, times(NUM_LARGE_DOCS)).bulkIndexWithRetry(any(BulkRequest.class), any(), any(), any());
}
}

0 comments on commit 1c2b083

Please sign in to comment.