From ee7427350f048bde21219a9783857bc96d9a3af7 Mon Sep 17 00:00:00 2001 From: Shengkai <1059623455@qq.com> Date: Thu, 25 Sep 2025 11:17:28 +0800 Subject: [PATCH 1/2] [FLINK-38423][table-api] Add VECTOR_SEARCH connector API --- .../source/VectorSearchTableSource.java | 117 ++++++++++++++++++ .../AsyncVectorSearchFunctionProvider.java | 38 ++++++ .../search/VectorSearchFunctionProvider.java | 37 ++++++ .../functions/AsyncVectorSearchFunction.java | 66 ++++++++++ .../table/functions/VectorSearchFunction.java | 62 ++++++++++ 5 files changed, 320 insertions(+) create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/VectorSearchTableSource.java create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/search/AsyncVectorSearchFunctionProvider.java create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/search/VectorSearchFunctionProvider.java create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AsyncVectorSearchFunction.java create mode 100644 flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/VectorSearchFunction.java diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/VectorSearchTableSource.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/VectorSearchTableSource.java new file mode 100644 index 0000000000000..b10a9896dae9b --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/VectorSearchTableSource.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.connector.source; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.configuration.ReadableConfig; +import org.apache.flink.table.connector.source.search.AsyncVectorSearchFunctionProvider; +import org.apache.flink.table.connector.source.search.VectorSearchFunctionProvider; +import org.apache.flink.types.RowKind; + +import java.io.Serializable; + +/** + * A {@link DynamicTableSource} that search rows of an external storage system by one or more + * vectors during runtime. + * + *

Compared to {@link ScanTableSource}, the source does not have to read the entire table and can + * lazily fetch individual values from a (possibly continuously changing) external table when + * necessary. + * + *

Note: Compared to {@link ScanTableSource}, a {@link VectorSearchTableSource} does only support + * emitting insert-only changes currently (see also {@link RowKind}). Further abilities are not + * supported. + * + *

In the last step, the planner will call {@link #getSearchRuntimeProvider(VectorSearchContext)} + * for obtaining a provider of runtime implementation. The search fields that are required to + * perform a search are derived from a query by the planner and will be provided in the given {@link + * VectorSearchTableSource.VectorSearchContext#getSearchColumns()}. The values for those key fields + * are passed during runtime. + */ +@PublicEvolving +public interface VectorSearchTableSource extends DynamicTableSource { + + /** + * Returns a provider of runtime implementation for reading the data. + * + *

There exist different interfaces for runtime implementation which is why {@link + * VectorSearchRuntimeProvider} serves as the base interface. + * + *

Independent of the provider interface, a source implementation can work on either + * arbitrary objects or internal data structures (see {@link org.apache.flink.table.data} for + * more information). + * + *

The given {@link VectorSearchContext} offers utilities by the planner for creating runtime + * implementation with minimal dependencies to internal data structures. + * + * @see VectorSearchFunctionProvider + * @see AsyncVectorSearchFunctionProvider + */ + VectorSearchRuntimeProvider getSearchRuntimeProvider(VectorSearchContext context); + + // -------------------------------------------------------------------------------------------- + // Helper interfaces + // -------------------------------------------------------------------------------------------- + + /** + * Context for creating runtime implementation via a {@link VectorSearchRuntimeProvider}. + * + *

It offers utilities by the planner for creating runtime implementation with minimal + * dependencies to internal data structures. + * + *

Methods should be called in {@link #getSearchRuntimeProvider(VectorSearchContext)}. + * Returned instances that are {@link Serializable} can be directly passed into the runtime + * implementation class. + */ + @PublicEvolving + interface VectorSearchContext extends DynamicTableSource.Context { + + /** + * Returns an array of key index paths that should be used during the search. The indices + * are 0-based and support composite keys within (possibly nested) structures. + * + *

For example, given a table with data type {@code ROW < i INT, s STRING, r ROW < i2 + * INT, s2 STRING > >}, this method would return {@code [[0], [2, 1]]} when {@code i} and + * {@code s2} are used for performing a lookup. + * + * @return array of key index paths + */ + int[][] getSearchColumns(); + + /** + * Runtime config provided to provider. The config can be used by planner or vector search + * provider at runtime. For example, async options can be used by planner to choose async + * inference. Other config such as http timeout or retry can be used to configure search + * functions. + */ + ReadableConfig runtimeConfig(); + } + + /** + * Provides actual runtime implementation for reading the data. + * + *

There exist different interfaces for runtime implementation which is why {@link + * VectorSearchRuntimeProvider} serves as the base interface. + * + * @see VectorSearchFunctionProvider + * @see AsyncVectorSearchFunctionProvider + */ + @PublicEvolving + interface VectorSearchRuntimeProvider {} +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/search/AsyncVectorSearchFunctionProvider.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/search/AsyncVectorSearchFunctionProvider.java new file mode 100644 index 0000000000000..9dd7a5083dca8 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/search/AsyncVectorSearchFunctionProvider.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.connector.source.search; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.connector.source.VectorSearchTableSource; +import org.apache.flink.table.functions.AsyncVectorSearchFunction; + +/** A provider for creating {@link AsyncVectorSearchFunction}. */ +@PublicEvolving +public interface AsyncVectorSearchFunctionProvider + extends VectorSearchTableSource.VectorSearchRuntimeProvider { + + /** Helper function for creating a static provider. */ + static AsyncVectorSearchFunctionProvider of( + AsyncVectorSearchFunction asyncVectorSearchFunction) { + return () -> asyncVectorSearchFunction; + } + + /** Creates an {@link AsyncVectorSearchFunction} instance. */ + AsyncVectorSearchFunction createAsyncVectorSearchFunction(); +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/search/VectorSearchFunctionProvider.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/search/VectorSearchFunctionProvider.java new file mode 100644 index 0000000000000..fe50ad585df3c --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/search/VectorSearchFunctionProvider.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.connector.source.search; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.connector.source.VectorSearchTableSource; +import org.apache.flink.table.functions.VectorSearchFunction; + +/** A provider for creating {@link VectorSearchFunction}. */ +@PublicEvolving +public interface VectorSearchFunctionProvider + extends VectorSearchTableSource.VectorSearchRuntimeProvider { + + /** Helper function for creating a static provider. */ + static VectorSearchFunctionProvider of(VectorSearchFunction searchFunction) { + return () -> searchFunction; + } + + /** Creates an {@link VectorSearchFunction} instance. */ + VectorSearchFunction createVectorSearchFunction(); +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AsyncVectorSearchFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AsyncVectorSearchFunction.java new file mode 100644 index 0000000000000..5641f559e9e57 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/AsyncVectorSearchFunction.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; + +import java.util.Collection; +import java.util.concurrent.CompletableFuture; + +/** + * A wrapper class of {@link AsyncTableFunction} for asynchronous vector search. + * + *

The output type of this table function is fixed as {@link RowData}. + */ +@PublicEvolving +public abstract class AsyncVectorSearchFunction extends AsyncTableFunction { + + /** + * Asynchronously search result based on input row to find topK matched rows. + * + * @param topK - The number of topK matched rows to return. + * @param queryData - A {@link RowData} that wraps input for search function. + * @return A collection of all searched results. + */ + public abstract CompletableFuture> asyncVectorSearch( + int topK, RowData queryData); + + /** Invokes {@link #asyncVectorSearch} and chains futures. */ + public void eval(CompletableFuture> future, Object... args) { + int topK = (int) args[0]; + GenericRowData argsData = GenericRowData.of(args[1]); + asyncVectorSearch(topK, argsData) + .whenComplete( + (result, exception) -> { + if (exception != null) { + future.completeExceptionally( + new TableException( + String.format( + "Failed to execute asynchronously search with input row %s.", + argsData), + exception)); + return; + } + future.complete(result); + }); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/VectorSearchFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/VectorSearchFunction.java new file mode 100644 index 0000000000000..3364e56de9b1f --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/VectorSearchFunction.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.util.FlinkRuntimeException; + +import java.io.IOException; +import java.util.Collection; + +/** + * A wrapper class of {@link TableFunction} for synchronous vector search. + * + *

The output type of this table function is fixed as {@link RowData}. + */ +@PublicEvolving +public abstract class VectorSearchFunction extends TableFunction { + + /** + * Synchronously search result based on input row to find topK matched rows. + * + * @param topK - The number of topK results to return. + * @param queryData - A {@link RowData} that wraps input for vector search function. + * @return A collection of predicted results. + */ + public abstract Collection vectorSearch(int topK, RowData queryData) + throws IOException; + + /** Invoke {@link #vectorSearch} and handle exceptions. */ + public final void eval(Object... args) { + int topK = (int) args[0]; + GenericRowData argsData = GenericRowData.of(args[1]); + try { + Collection results = vectorSearch(topK, argsData); + if (results == null) { + return; + } + results.forEach(this::collect); + } catch (Exception e) { + throw new FlinkRuntimeException( + String.format("Failed to execute search with input row %s.", argsData), e); + } + } +} From 620992966aea5d30bbc4433b156fa66532eb85b3 Mon Sep 17 00:00:00 2001 From: Shengkai <1059623455@qq.com> Date: Thu, 25 Sep 2025 14:49:07 +0800 Subject: [PATCH 2/2] [FLINK-38424][planner] Support to parse VECTOR_SEARCH function --- .../sql/validate/SqlValidatorImpl.java | 18 ++ .../functions/sql/FlinkSqlOperatorTable.java | 4 + .../sql/ml/SqlVectorSearchTableFunction.java | 239 ++++++++++++++++++ .../functions/utils/SqlValidatorUtils.java | 22 +- .../sql/VectorSearchTableFunctionTest.java | 212 ++++++++++++++++ .../sql/VectorSearchTableFunctionTest.xml | 141 +++++++++++ 6 files changed, 628 insertions(+), 8 deletions(-) create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java create mode 100644 flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java create mode 100644 flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java index 6e94d8979746c..3aa57ed967ec1 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java @@ -17,6 +17,7 @@ package org.apache.calcite.sql.validate; import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding; +import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -173,6 +174,8 @@ *

Lines 2012 ~ 2032, Flink improves error message for functions without appropriate arguments in * handleUnresolvedFunction at {@link SqlValidatorImpl#handleUnresolvedFunction}. * + *

Lines 2608 ~ 2619, Flink sets correct scope for {@link SqlVectorSearchTableFunction}. + * *

Lines 3840 ~ 3844, 6511 ~ 6517 Flink improves Optimize the retrieval of sub-operands in * SqlCall when using NamedParameters at {@link SqlValidatorImpl#checkRollUp}. * @@ -2599,6 +2602,21 @@ private SqlNode registerFrom( scopes.put(node, getSelectScope(call1.operand(0))); return newNode; } + + // Related to CALCITE-4077 + // ----- FLINK MODIFICATION BEGIN ----- + FlinkSqlCallBinding binding = + new FlinkSqlCallBinding(this, getEmptyScope(), call1); + if (op instanceof SqlVectorSearchTableFunction + && binding.operand(0) + .isA( + new HashSet<>( + Collections.singletonList(SqlKind.SELECT)))) { + SqlValidatorScope scope = getSelectScope((SqlSelect) binding.operand(0)); + scopes.put(node, scope); + return newNode; + } + // ----- FLINK MODIFICATION END ----- } // Put the usingScope which can be a JoinScope // or a SelectScope, in order to see the left items diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java index bae972abaeb52..2178bf749d4cf 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java @@ -23,6 +23,7 @@ import org.apache.flink.table.planner.functions.sql.internal.SqlAuxiliaryGroupAggFunction; import org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction; import org.apache.flink.table.planner.functions.sql.ml.SqlMLPredictTableFunction; +import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction; import org.apache.flink.table.planner.plan.type.FlinkReturnTypes; import org.apache.flink.table.planner.plan.type.NumericExceptFirstOperandChecker; @@ -1330,6 +1331,9 @@ public List getAuxiliaryFunctions() { public static final SqlFunction ML_PREDICT = new SqlMLPredictTableFunction(); public static final SqlFunction ML_EVALUATE = new SqlMLEvaluateTableFunction(); + // SEARCH FUNCTIONS + public static final SqlFunction VECTOR_SEARCH = new SqlVectorSearchTableFunction(); + // Catalog Functions public static final SqlFunction CURRENT_DATABASE = BuiltInSqlFunction.newBuilder() diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java new file mode 100644 index 0000000000000..a655efdf9f072 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions.sql.ml; + +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.utils.LogicalTypeCasts; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeFieldImpl; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperandCountRange; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.SqlTableFunction; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlOperandCountRanges; +import org.apache.calcite.sql.type.SqlOperandMetadata; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlNameMatcher; +import org.apache.calcite.util.Util; +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType; + +/** + * {@link SqlVectorSearchTableFunction} implements an operator for search. + * + *

It allows four parameters: + * + *

    + *
  1. a table + *
  2. a descriptor to provide a column name from the input table + *
  3. a query column from the left table + *
  4. a literal value for top k + *
+ */ +public class SqlVectorSearchTableFunction extends SqlFunction implements SqlTableFunction { + + private static final String PARAM_SEARCH_TABLE = "SEARCH_TABLE"; + private static final String PARAM_COLUMN_TO_SEARCH = "COLUMN_TO_SEARCH"; + private static final String PARAM_COLUMN_TO_QUERY = "COLUMN_TO_QUERY"; + private static final String PARAM_TOP_K = "TOP_K"; + + private static final String OUTPUT_SCORE = "score"; + + public SqlVectorSearchTableFunction() { + super( + "VECTOR_SEARCH", + SqlKind.OTHER_FUNCTION, + ReturnTypes.CURSOR, + null, + new OperandMetadataImpl(), + SqlFunctionCategory.SYSTEM); + } + + @Override + public SqlReturnTypeInference getRowTypeInference() { + return new SqlReturnTypeInference() { + @Override + public @Nullable RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + final RelDataType inputRowType = opBinding.getOperandType(0); + + return typeFactory + .builder() + .kind(inputRowType.getStructKind()) + .addAll(inputRowType.getFieldList()) + .addAll( + SqlValidatorUtils.makeOutputUnique( + inputRowType.getFieldList(), + Collections.singletonList( + new RelDataTypeFieldImpl( + OUTPUT_SCORE, + 0, + typeFactory.createSqlType( + SqlTypeName.DOUBLE))))) + .build(); + } + }; + } + + @Override + public boolean argumentMustBeScalar(int ordinal) { + return ordinal != 0; + } + + private static class OperandMetadataImpl implements SqlOperandMetadata { + + private static final List PARAMETERS = + Collections.unmodifiableList( + Arrays.asList( + PARAM_SEARCH_TABLE, + PARAM_COLUMN_TO_SEARCH, + PARAM_COLUMN_TO_QUERY, + PARAM_TOP_K)); + + @Override + public List paramTypes(RelDataTypeFactory relDataTypeFactory) { + return Collections.nCopies( + PARAMETERS.size(), relDataTypeFactory.createSqlType(SqlTypeName.ANY)); + } + + @Override + public List paramNames() { + return PARAMETERS; + } + + @Override + public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + // check vector table contains descriptor columns + if (!SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, 1)) { + return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse( + callBinding, throwOnFailure); + } + + List operands = callBinding.operands(); + // check descriptor has one column + SqlCall descriptor = (SqlCall) operands.get(1); + List descriptorCols = descriptor.getOperandList(); + if (descriptorCols.size() != 1) { + return SqlValidatorUtils.throwExceptionOrReturnFalse( + Optional.of( + new ValidationException( + String.format( + "Expect parameter COLUMN_TO_SEARCH for VECTOR_SEARCH only contains one column, but multiple columns are found in operand %s.", + descriptor))), + throwOnFailure); + } + + // check descriptor type is ARRAY or ARRAY + RelDataType searchTableType = callBinding.getOperandType(0); + SqlNameMatcher matcher = callBinding.getValidator().getCatalogReader().nameMatcher(); + SqlIdentifier columnName = (SqlIdentifier) descriptorCols.get(0); + String descriptorColName = + columnName.isSimple() ? columnName.getSimple() : Util.last(columnName.names); + int index = matcher.indexOf(searchTableType.getFieldNames(), descriptorColName); + RelDataType targetType = searchTableType.getFieldList().get(index).getType(); + LogicalType targetLogicalType = toLogicalType(targetType); + + if (!(targetLogicalType.is(LogicalTypeRoot.ARRAY) + && ((ArrayType) (targetLogicalType)) + .getElementType() + .isAnyOf(LogicalTypeRoot.FLOAT, LogicalTypeRoot.DOUBLE))) { + return SqlValidatorUtils.throwExceptionOrReturnFalse( + Optional.of( + new ValidationException( + String.format( + "Expect search column `%s` type is ARRAY or ARRAY, but its type is %s.", + columnName, targetType))), + throwOnFailure); + } + + // check query type is ARRAY or ARRAY + LogicalType sourceLogicalType = toLogicalType(callBinding.getOperandType(2)); + if (!LogicalTypeCasts.supportsImplicitCast(sourceLogicalType, targetLogicalType)) { + return SqlValidatorUtils.throwExceptionOrReturnFalse( + Optional.of( + new ValidationException( + String.format( + "Can not cast the query column type %s to target type %s. Please keep the query column type is same to the search column type.", + sourceLogicalType, targetType))), + throwOnFailure); + } + + // check topK is literal + LogicalType topKType = toLogicalType(callBinding.getOperandType(3)); + if (!operands.get(3).getKind().equals(SqlKind.LITERAL) + || !topKType.is(LogicalTypeRoot.INTEGER)) { + return SqlValidatorUtils.throwExceptionOrReturnFalse( + Optional.of( + new ValidationException( + String.format( + "Expect parameter topK is integer literal in VECTOR_SEARCH, but it is %s with type %s.", + operands.get(3), topKType))), + throwOnFailure); + } + + return true; + } + + @Override + public SqlOperandCountRange getOperandCountRange() { + return SqlOperandCountRanges.between(4, 4); + } + + @Override + public String getAllowedSignatures(SqlOperator op, String opName) { + return opName + "(TABLE table_name, DESCRIPTOR(query_column), search_column, top_k)"; + } + + @Override + public Consistency getConsistency() { + return Consistency.NONE; + } + + @Override + public boolean isOptional(int i) { + return false; + } + + @Override + public boolean isFixedParameters() { + return true; + } + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java index 42e381a606290..66b58499e0949 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/utils/SqlValidatorUtils.java @@ -160,27 +160,33 @@ private static void adjustTypeForMultisetConstructor( /** * Make output field names unique from input field names by appending index. For example, Input * has field names {@code a, b, c} and output has field names {@code b, c, d}. After calling - * this function, new output field names will be {@code b0, c0, d}. Duplicate names are not - * checked inside input and output itself. + * this function, new output field names will be {@code b0, c0, d}. + * + *

We assume that input fields in the input parameter are uniquely named, just as the output + * fields in the output parameter are. * * @param input Input fields * @param output Output fields - * @return + * @return output fields with unique names. */ public static List makeOutputUnique( List input, List output) { - final Set inputFieldNames = new HashSet<>(); + final Set uniqueNames = new HashSet<>(); for (RelDataTypeField field : input) { - inputFieldNames.add(field.getName()); + uniqueNames.add(field.getName()); } List result = new ArrayList<>(); for (RelDataTypeField field : output) { String fieldName = field.getName(); - if (inputFieldNames.contains(fieldName)) { - fieldName += "0"; // Append index to make it unique + int count = 0; + String candidate = fieldName; + while (uniqueNames.contains(candidate)) { + candidate = fieldName + count; + count++; } - result.add(new RelDataTypeFieldImpl(fieldName, field.getIndex(), field.getType())); + uniqueNames.add(candidate); + result.add(new RelDataTypeFieldImpl(candidate, field.getIndex(), field.getType())); } return result; } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java new file mode 100644 index 0000000000000..818abc6b6d6dd --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.stream.sql; + +import org.apache.flink.core.testutils.FlinkAssertions; +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.planner.functions.sql.ml.SqlVectorSearchTableFunction; +import org.apache.flink.table.planner.utils.TableTestBase; +import org.apache.flink.table.planner.utils.TableTestUtil; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Test for {@link SqlVectorSearchTableFunction}. */ +public class VectorSearchTableFunctionTest extends TableTestBase { + + private TableTestUtil util; + + @BeforeEach + public void setup() { + util = streamTestUtil(TableConfig.getDefault()); + + // Create test table + util.tableEnv() + .executeSql( + "CREATE TABLE QueryTable (\n" + + " a INT,\n" + + " b BIGINT,\n" + + " c STRING,\n" + + " d ARRAY,\n" + + " rowtime TIMESTAMP(3),\n" + + " proctime as PROCTIME(),\n" + + " WATERMARK FOR rowtime AS rowtime - INTERVAL '1' SECOND\n" + + ") with (\n" + + " 'connector' = 'values'\n" + + ")"); + + util.tableEnv() + .executeSql( + "CREATE TABLE VectorTable (\n" + + " e INT,\n" + + " f BIGINT,\n" + + " g ARRAY\n" + + ") with (\n" + + " 'connector' = 'values'\n" + + ")"); + } + + @Test + void testSimple() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.d, 10" + + ")\n" + + ")"; + util.verifyRelPlan(sql); + } + + @Test + void testLiteralValue() { + String sql = + "SELECT * FROM LATERAL TABLE(VECTOR_SEARCH(TABLE VectorTable, DESCRIPTOR(`g`), ARRAY[1.5, 2.0], 10))"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + TableException.class, + "FlinkLogicalTableFunctionScan(invocation=[VECTOR_SEARCH(TABLE(#0), DESCRIPTOR(_UTF-16LE'g'), ARRAY(1.5:DECIMAL(2, 1), 2.0:DECIMAL(2, 1)), 10)], rowType=[RecordType(INTEGER e, BIGINT f, FLOAT ARRAY g, DOUBLE score)])\n" + + "+- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, VectorTable]], fields=[e, f, g])")); + } + + @Test + void testNamedArgument() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " SEARCH_TABLE => TABLE VectorTable,\n" + + " COLUMN_TO_QUERY => QueryTable.d,\n" + + " COLUMN_TO_SEARCH => DESCRIPTOR(`g`),\n" + + " TOP_K => 10" + + " )\n" + + ")"; + util.verifyRelPlan(sql); + } + + @Test + void testOutOfOrderNamedArgument() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " COLUMN_TO_QUERY => QueryTable.d,\n" + + " COLUMN_TO_SEARCH => DESCRIPTOR(`g`),\n" + + " TOP_K => 10,\n" + + " SEARCH_TABLE => TABLE VectorTable\n" + + " )\n" + + ")"; + util.verifyRelPlan(sql); + } + + @Test + void testNameConflicts() { + util.tableEnv() + .executeSql( + "CREATE TABLE NameConflictTable(\n" + + " a INT,\n" + + " score ARRAY,\n" + + " score0 ARRAY,\n" + + " score1 ARRAY\n" + + ") WITH (\n" + + " 'connector' = 'values'\n" + + ")"); + util.verifyRelPlan( + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE NameConflictTable, DESCRIPTOR(`score`), QueryTable.d, 10))"); + } + + @Test + void testDescriptorTypeIsNotExpected() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`f`), QueryTable.d, 10" + + ")\n" + + ")"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + ValidationException.class, + "Expect search column `f` type is ARRAY or ARRAY, but its type is BIGINT.")); + } + + @Test + void testDescriptorContainsMultipleColumns() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`f`, `g`), QueryTable.d, 10" + + ")\n" + + ")"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + ValidationException.class, + "Expect parameter COLUMN_TO_SEARCH for VECTOR_SEARCH only contains one column, but multiple columns are found in operand DESCRIPTOR(`f`, `g`).")); + } + + @Test + void testQueryColumnIsNotArray() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.c, 10" + + ")\n" + + ")"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + ValidationException.class, + "Can not cast the query column type STRING to target type FLOAT ARRAY. Please keep the query column type is same to the search column type.")); + } + + @Test + void testIllegalTopKValue1() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.d, 10.0" + + ")\n" + + ")"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + ValidationException.class, + "Expect parameter topK is integer literal in VECTOR_SEARCH, but it is 10.0 with type DECIMAL(3, 1) NOT NULL.")); + } + + @Test + void testIllegalTopKValue2() { + String sql = + "SELECT * FROM QueryTable, LATERAL TABLE(\n" + + "VECTOR_SEARCH(\n" + + " TABLE VectorTable, DESCRIPTOR(`g`), QueryTable.d, QueryTable.a" + + ")\n" + + ")"; + assertThatThrownBy(() -> util.verifyRelPlan(sql)) + .satisfies( + FlinkAssertions.anyCauseMatches( + ValidationException.class, + "Expect parameter topK is integer literal in VECTOR_SEARCH, but it is QueryTable.a with type INT.")); + } +} diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml new file mode 100644 index 0000000000000..8aca81dc52d4c --- /dev/null +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/VectorSearchTableFunctionTest.xml @@ -0,0 +1,141 @@ + + + + + + + + + + + + + + + + + TABLE VectorTable, + COLUMN_TO_QUERY => QueryTable.d, + COLUMN_TO_SEARCH => DESCRIPTOR(`g`), + TOP_K => 10 ) +)]]> + + + + + + + + + + + + + + + + + + + + + + QueryTable.d, + COLUMN_TO_SEARCH => DESCRIPTOR(`g`), + TOP_K => 10, + SEARCH_TABLE => TABLE VectorTable + ) +)]]> + + + + + + + + +