Skip to content

Commit

Permalink
Merge branch '2.x' of github.com:opensearch-project/sql into backport…
Browse files Browse the repository at this point in the history
…/backport-2772-to-2.x
  • Loading branch information
derek-ho committed Jul 15, 2024
2 parents 5bc5146 + 9c2b01d commit 30c8237
Show file tree
Hide file tree
Showing 32 changed files with 421 additions and 355 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/integ-tests-with-security.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ jobs:
# need to switch to root so that github actions can install runner binary on container without permission issues.
options: --user root

# Allow using Node16 actions
env:
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true

steps:
- uses: actions/checkout@v3

Expand Down
4 changes: 0 additions & 4 deletions .github/workflows/sql-test-and-build-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ jobs:
# need to switch to root so that github actions can install runner binary on container without permission issues.
options: --user root

# Allow using Node16 actions
env:
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true

steps:
- uses: actions/checkout@v3

Expand Down
1 change: 0 additions & 1 deletion async-query-core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ dependencies {
implementation project(':core')
implementation project(':spark') // TODO: dependency to spark should be eliminated
implementation project(':datasources') // TODO: dependency to datasources should be eliminated
implementation project(':legacy') // TODO: dependency to legacy should be eliminated
implementation 'org.json:json:20231013'
implementation 'com.google.code.gson:gson:2.8.9'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ public interface EMRServerlessClientFactory {
/**
* Gets an instance of {@link EMRServerlessClient}.
*
* @param accountId Account ID of the requester. It will be used to decide the cluster.
* @return An {@link EMRServerlessClient} instance.
*/
EMRServerlessClient getClient();
EMRServerlessClient getClient(String accountId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.metrics.MetricsService;

/** Implementation of {@link EMRServerlessClientFactory}. */
@RequiredArgsConstructor
public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory {

Expand All @@ -27,13 +26,8 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactor
private EMRServerlessClient emrServerlessClient;
private String region;

/**
* Gets an instance of {@link EMRServerlessClient}.
*
* @return An {@link EMRServerlessClient} instance.
*/
@Override
public EMRServerlessClient getClient() {
public EMRServerlessClient getClient(String accountId) {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(
new NullAsyncQueryRequestContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,27 @@ public class QueryHandlerFactory {
private final EMRServerlessClientFactory emrServerlessClientFactory;
private final MetricsService metricsService;

public RefreshQueryHandler getRefreshQueryHandler() {
public RefreshQueryHandler getRefreshQueryHandler(String accountId) {
return new RefreshQueryHandler(
emrServerlessClientFactory.getClient(),
emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
flintIndexMetadataService,
leaseManager,
flintIndexOpFactory,
metricsService);
}

public StreamingQueryHandler getStreamingQueryHandler() {
public StreamingQueryHandler getStreamingQueryHandler(String accountId) {
return new StreamingQueryHandler(
emrServerlessClientFactory.getClient(),
emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
leaseManager,
metricsService);
}

public BatchQueryHandler getBatchQueryHandler() {
public BatchQueryHandler getBatchQueryHandler(String accountId) {
return new BatchQueryHandler(
emrServerlessClientFactory.getClient(),
emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
leaseManager,
metricsService);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ public DispatchQueryResponse dispatch(
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getQueryHandlerForFlintExtensionQuery(indexQueryDetails)
return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context);
return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId())
.submit(dispatchQueryRequest, context);
}
}

Expand All @@ -74,28 +75,28 @@ private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchConte
}

private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(
IndexQueryDetails indexQueryDetails) {
DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) {
if (isEligibleForIndexDMLHandling(indexQueryDetails)) {
return queryHandlerFactory.getIndexDMLHandler();
} else if (isEligibleForStreamingQuery(indexQueryDetails)) {
return queryHandlerFactory.getStreamingQueryHandler();
return queryHandlerFactory.getStreamingQueryHandler(dispatchQueryRequest.getAccountId());
} else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())) {
// Create should be handled by batch handler. This is to avoid DROP index incorrectly cancel
// an interactive job.
return queryHandlerFactory.getBatchQueryHandler();
return queryHandlerFactory.getBatchQueryHandler(dispatchQueryRequest.getAccountId());
} else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
// Manual refresh should be handled by batch handler
return queryHandlerFactory.getRefreshQueryHandler();
return queryHandlerFactory.getRefreshQueryHandler(dispatchQueryRequest.getAccountId());
} else {
return getDefaultAsyncQueryHandler();
return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId());
}
}

@NotNull
private AsyncQueryHandler getDefaultAsyncQueryHandler() {
private AsyncQueryHandler getDefaultAsyncQueryHandler(String accountId) {
return sessionManager.isEnabled()
? queryHandlerFactory.getInteractiveQueryHandler()
: queryHandlerFactory.getBatchQueryHandler();
: queryHandlerFactory.getBatchQueryHandler(accountId);
}

@NotNull
Expand Down Expand Up @@ -143,11 +144,11 @@ private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
return queryHandlerFactory.getIndexDMLHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) {
return queryHandlerFactory.getRefreshQueryHandler();
return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId());
} else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) {
return queryHandlerFactory.getStreamingQueryHandler();
return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId());
} else {
return queryHandlerFactory.getBatchQueryHandler();
return queryHandlerFactory.getBatchQueryHandler(asyncQueryJobMetadata.getAccountId());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public Session createSession(
.sessionId(sessionIdProvider.getSessionId(request))
.sessionStorageService(sessionStorageService)
.statementStorageService(statementStorageService)
.serverlessClient(emrServerlessClientFactory.getClient())
.serverlessClient(emrServerlessClientFactory.getClient(request.getAccountId()))
.build();
session.open(request, asyncQueryRequestContext);
return session;
Expand Down Expand Up @@ -65,7 +65,7 @@ public Optional<Session> getSession(String sessionId, String dataSourceName) {
.sessionId(sessionId)
.sessionStorageService(sessionStorageService)
.statementStorageService(statementStorageService)
.serverlessClient(emrServerlessClientFactory.getClient())
.serverlessClient(emrServerlessClientFactory.getClient(model.get().getAccountId()))
.sessionModel(model.get())
.sessionInactivityTimeoutMilli(
sessionConfigSupplier.getSessionInactivityTimeoutMillis())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ public void cancelStreamingJob(FlintIndexStateModel flintIndexStateModel)
throws InterruptedException, TimeoutException {
String applicationId = flintIndexStateModel.getApplicationId();
String jobId = flintIndexStateModel.getJobId();
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
EMRServerlessClient emrServerlessClient =
emrServerlessClientFactory.getClient(flintIndexStateModel.getAccountId());
try {
emrServerlessClient.cancelJobRun(
flintIndexStateModel.getApplicationId(), flintIndexStateModel.getJobId(), true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@

package org.opensearch.sql.spark.rest.model;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import lombok.Data;
import org.apache.commons.lang3.Validate;
import org.opensearch.core.xcontent.XContentParser;

@Data
public class CreateAsyncQueryRequest {
Expand All @@ -32,35 +28,4 @@ public CreateAsyncQueryRequest(String query, String datasource, LangType lang, S
this.lang = Validate.notNull(lang, "lang can't be null");
this.sessionId = sessionId;
}

public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser)
throws IOException {
String query = null;
LangType lang = null;
String datasource = null;
String sessionId = null;
try {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
if (fieldName.equals("query")) {
query = parser.textOrNull();
} else if (fieldName.equals("lang")) {
String langString = parser.textOrNull();
lang = LangType.fromString(langString);
} else if (fieldName.equals("datasource")) {
datasource = parser.textOrNull();
} else if (fieldName.equals("sessionId")) {
sessionId = parser.textOrNull();
} else {
throw new IllegalArgumentException("Unknown field: " + fieldName);
}
}
return new CreateAsyncQueryRequest(query, datasource, lang, sessionId);
} catch (Exception e) {
throw new IllegalArgumentException(
String.format("Error while parsing the request body: %s", e.getMessage()));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.spark.utils;

import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import lombok.Getter;
import lombok.experimental.UtilityClass;
Expand All @@ -18,6 +20,7 @@
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser;
import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.IdentifierReferenceContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor;
import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
Expand All @@ -32,16 +35,15 @@
@UtilityClass
public class SQLQueryUtils {

// TODO Handle cases where the query has multiple table Names.
public static FullyQualifiedTableName extractFullyQualifiedTableName(String sqlQuery) {
public static List<FullyQualifiedTableName> extractFullyQualifiedTableNames(String sqlQuery) {
SqlBaseParser sqlBaseParser =
new SqlBaseParser(
new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor();
statement.accept(sparkSqlTableNameVisitor);
return sparkSqlTableNameVisitor.getFullyQualifiedTableName();
return sparkSqlTableNameVisitor.getFullyQualifiedTableNames();
}

public static IndexQueryDetails extractIndexDetails(String sqlQuery) {
Expand Down Expand Up @@ -73,23 +75,21 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {

public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {

@Getter private FullyQualifiedTableName fullyQualifiedTableName;
@Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();

public SparkSqlTableNameVisitor() {
this.fullyQualifiedTableName = new FullyQualifiedTableName();
}
public SparkSqlTableNameVisitor() {}

@Override
public Void visitTableName(SqlBaseParser.TableNameContext ctx) {
fullyQualifiedTableName = new FullyQualifiedTableName(ctx.getText());
return super.visitTableName(ctx);
public Void visitIdentifierReference(IdentifierReferenceContext ctx) {
fullyQualifiedTableNames.add(new FullyQualifiedTableName(ctx.getText()));
return super.visitIdentifierReference(ctx);
}

@Override
public Void visitDropTable(SqlBaseParser.DropTableContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitDropTable(ctx);
Expand All @@ -99,7 +99,7 @@ public Void visitDropTable(SqlBaseParser.DropTableContext ctx) {
public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitDescribeRelation(ctx);
Expand All @@ -110,7 +110,7 @@ public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) {
public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitCreateTableHeader(ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public class AsyncQueryCoreIntegTest {
@BeforeEach
public void setUp() {
emrServerlessClientFactory =
() -> new EmrServerlessClientImpl(awsemrServerless, metricsService);
(accountId) -> new EmrServerlessClientImpl(awsemrServerless, metricsService);
SessionManager sessionManager =
new SessionManager(
sessionStorageService,
Expand Down
Loading

0 comments on commit 30c8237

Please sign in to comment.