Skip to content

Commit

Permalink
change implicit casting rules
Browse files Browse the repository at this point in the history
  • Loading branch information
mxwli committed Aug 7, 2024
1 parent 7139a6a commit 8d0316a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,11 @@ uint32_t BuiltInFunctionsUtils::getCastCost(LogicalTypeID inputTypeID, LogicalTy
return castFromRDFVariant(inputTypeID);
}
if (targetTypeID == LogicalTypeID::STRING) {
return castFromString(inputTypeID);
return castToString(inputTypeID);
}
switch (inputTypeID) {
case LogicalTypeID::STRING:
return getTargetTypeCost(targetTypeID);
case LogicalTypeID::INT64:
return castInt64(targetTypeID);
case LogicalTypeID::INT32:
Expand Down Expand Up @@ -391,7 +393,7 @@ uint32_t BuiltInFunctionsUtils::castTimestamp(LogicalTypeID targetTypeID) {
}
}

uint32_t BuiltInFunctionsUtils::castFromString(LogicalTypeID inputTypeID) {
uint32_t BuiltInFunctionsUtils::castToString(LogicalTypeID inputTypeID) {
switch (inputTypeID) {
case LogicalTypeID::BLOB:
case LogicalTypeID::INTERNAL_ID:
Expand Down
2 changes: 1 addition & 1 deletion src/include/function/built_in_function_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class BuiltInFunctionsUtils {

static uint32_t castTimestamp(common::LogicalTypeID targetTypeID);

static uint32_t castFromString(common::LogicalTypeID inputTypeID);
static uint32_t castToString(common::LogicalTypeID inputTypeID);

static uint32_t castFromRDFVariant(common::LogicalTypeID inputTypeID);

Expand Down
11 changes: 11 additions & 0 deletions tools/python_api/test/test_scan_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,14 @@ def test_copy_from_pandas_date(tmp_path: Path) -> None:
assert result.get_next() == [1, datetime.datetime(2024,1,3)]
assert result.get_next() == [2, datetime.datetime(2023,10,10)]
assert result.has_next() is False

def test_scan_string_to_list(tmp_path: Path) -> None:
db = kuzu.Database(tmp_path)
conn = kuzu.Connection(db)
df = pd.DataFrame({"col": ["[1, 2, 3]"]})
conn.execute("CREATE NODE TABLE tab(id SERIAL, col INT64[], PRIMARY KEY (id));")
conn.execute("COPY tab FROM df;")
result = conn.execute("match (t:tab) return t.*")
assert result.get_next() == [0, [1, 2, 3]]
assert not result.has_next()

0 comments on commit 8d0316a

Please sign in to comment.