diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java index d84c91bdb8..94c1f79511 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -62,9 +62,15 @@ public void open() { /** Cancel a statement. */ public void cancel() { - if (statementModel.getStatementState().equals(StatementState.RUNNING)) { + StatementState statementState = statementModel.getStatementState(); + + if (statementState.equals(StatementState.SUCCESS) + || statementState.equals(StatementState.FAILED) + || statementState.equals(StatementState.CANCELLED)) { String errorMsg = - String.format("can't cancel statement in waiting state. statement: %s.", statementId); + String.format( + "can't cancel statement in %s state. statement: %s.", + statementState.getState(), statementId); LOG.error(errorMsg); throw new IllegalStateException(errorMsg); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index 1e33c8a6b9..29020f2496 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -8,6 +8,7 @@ import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.createSessionRequest; import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; +import static org.opensearch.sql.spark.execution.statement.StatementState.RUNNING; import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; @@ -168,38 +169,93 @@ public void cancelFailedBecauseOfConflict() { } @Test - public void cancelRunningStatementFailed() { + public void cancelSuccessStatementFailed() { StatementId stId = new StatementId("statementId"); - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(stId) - .langType(LangType.SQL) - .datasourceName(DS_NAME) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); - st.open(); + Statement st = createStatement(stId); + + // update to running state + StatementModel model = st.getStatementModel(); + st.setStatementModel( + StatementModel.copyWithState( + st.getStatementModel(), + StatementState.SUCCESS, + model.getSeqNo(), + model.getPrimaryTerm())); + + // cancel conflict + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("can't cancel statement in success state. statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelFailedStatementFailed() { + StatementId stId = new StatementId("statementId"); + Statement st = createStatement(stId); // update to running state StatementModel model = st.getStatementModel(); st.setStatementModel( StatementModel.copyWithState( st.getStatementModel(), - StatementState.RUNNING, + StatementState.FAILED, model.getSeqNo(), model.getPrimaryTerm())); // cancel conflict IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); assertEquals( - String.format("can't cancel statement in waiting state. statement: %s.", stId), + String.format("can't cancel statement in failed state. statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelCancelledStatementFailed() { + StatementId stId = new StatementId("statementId"); + Statement st = createStatement(stId); + + // update to running state + StatementModel model = st.getStatementModel(); + st.setStatementModel( + StatementModel.copyWithState( + st.getStatementModel(), CANCELLED, model.getSeqNo(), model.getPrimaryTerm())); + + // cancel conflict + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("can't cancel statement in cancelled state. statement: %s.", stId), exception.getMessage()); } + @Test + public void cancelRunningStatementSuccess() { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .datasourceName(DS_NAME) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + + // submit statement + TestStatement testStatement = testStatement(st, stateStore); + testStatement + .open() + .assertSessionState(WAITING) + .assertStatementId(new StatementId("statementId")); + + testStatement.run(); + + // close statement + testStatement.cancel().assertSessionState(CANCELLED); + } + @Test public void submitStatementInRunningSession() { Session session = @@ -355,9 +411,33 @@ public TestStatement cancel() { st.cancel(); return this; } + + public TestStatement run() { + StatementModel model = + updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), RUNNING); + st.setStatementModel(model); + return this; + } } private QueryRequest queryRequest() { return new QueryRequest(AsyncQueryId.newAsyncQueryId(DS_NAME), LangType.SQL, "select 1"); } + + private Statement createStatement(StatementId stId) { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(stId) + .langType(LangType.SQL) + .datasourceName(DS_NAME) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + return st; + } }