Skip to content

Commit

Permalink
resolve weird ANY resolution (#3160)
Browse files Browse the repository at this point in the history
  • Loading branch information
mxwli authored Mar 28, 2024
1 parent 015bf23 commit 6c82aad
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 24 deletions.
37 changes: 19 additions & 18 deletions tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,12 @@ static std::unordered_map<std::string, std::unique_ptr<Value>> transformPythonPa

static bool canCastPyLogicalType(const LogicalType& from, const LogicalType& to) {
// the input of this function is restricted to the output of pyLogicalType
if (from.getLogicalTypeID() == LogicalTypeID::MAP) {
if (from.getLogicalTypeID() == LogicalTypeID::ANY ||
from.getLogicalTypeID() == to.getLogicalTypeID()) {
return true;
} else if (to.getLogicalTypeID() == LogicalTypeID::ANY) {
return false;
} else if (from.getLogicalTypeID() == LogicalTypeID::MAP) {
if (to.getLogicalTypeID() != LogicalTypeID::MAP) {
return false;
}
Expand All @@ -195,9 +200,6 @@ static bool canCastPyLogicalType(const LogicalType& from, const LogicalType& to)
}
return canCastPyLogicalType(
*VarListType::getChildType(&from), *VarListType::getChildType(&to));
} else if (from.getLogicalTypeID() == LogicalTypeID::ANY ||
from.getLogicalTypeID() == to.getLogicalTypeID()) {
return true;
} else {
auto castCost = function::BuiltInFunctionsUtils::getCastCost(
from.getLogicalTypeID(), to.getLogicalTypeID());
Expand All @@ -206,24 +208,27 @@ static bool canCastPyLogicalType(const LogicalType& from, const LogicalType& to)
return false;
}

static void tryConvertPyLogicalType(LogicalType& from, LogicalType& to);

static std::unique_ptr<LogicalType> castPyLogicalType(const LogicalType& from, const LogicalType& to) {
// assumes from can cast to to
if (from.getLogicalTypeID() == LogicalTypeID::MAP) {
auto fromKeyType = MapType::getKeyType(&from), fromValueType = MapType::getValueType(&to);
auto fromKeyType = MapType::getKeyType(&from), fromValueType = MapType::getValueType(&from);
auto toKeyType = MapType::getKeyType(&to), toValueType = MapType::getValueType(&to);
auto outputKeyType = canCastPyLogicalType(*fromKeyType, *toKeyType) ? toKeyType : fromKeyType;
auto outputValueType = canCastPyLogicalType(*fromValueType, *toValueType) ? toValueType : fromValueType;
return LogicalType::MAP(
std::make_unique<LogicalType>(*outputKeyType), std::make_unique<LogicalType>(*outputValueType));
auto outputKeyType = canCastPyLogicalType(*fromKeyType, *toKeyType) ?
castPyLogicalType(*fromKeyType, *toKeyType) : castPyLogicalType(*toKeyType, *fromKeyType);
auto outputValueType = canCastPyLogicalType(*fromValueType, *toValueType) ?
castPyLogicalType(*fromValueType, *toValueType) : castPyLogicalType(*toValueType, *fromValueType);
return LogicalType::MAP(std::move(outputKeyType), std::move(outputValueType));
}
return std::make_unique<LogicalType>(to);
}

static void tryConvertPyLogicalType(LogicalType& from, LogicalType& to) {
void tryConvertPyLogicalType(LogicalType& from, LogicalType& to) {
if (canCastPyLogicalType(from, to)) {
from = *castPyLogicalType(from, to);
} else if (canCastPyLogicalType(to, from)) {
from = *castPyLogicalType(to, from);
to = *castPyLogicalType(to, from);
} else {
throw RuntimeException(stringFormat(
"Cannot convert Python object to Kuzu value : {} is incompatible with {}",
Expand Down Expand Up @@ -257,9 +262,6 @@ static std::unique_ptr<LogicalType> pyLogicalType(py::handle val) {
} else if (py::isinstance<py::list>(val)) {
py::list lst = py::reinterpret_borrow<py::list>(val);
auto childType = LogicalType::ANY();
if (py::len(lst) == 0) {
childType = LogicalType::STRING();
}
for (auto child : lst) {
auto curChildType = pyLogicalType(child);
tryConvertPyLogicalType(*childType, *curChildType);
Expand All @@ -268,10 +270,6 @@ static std::unique_ptr<LogicalType> pyLogicalType(py::handle val) {
} else if (py::isinstance<py::dict>(val)) {
py::dict dict = py::reinterpret_borrow<py::dict>(val);
auto childKeyType = LogicalType::ANY(), childValueType = LogicalType::ANY();
if (py::len(dict) == 0) {
childKeyType = LogicalType::STRING();
childValueType = LogicalType::STRING();
}
for (auto child : dict) {
auto curChildKeyType = pyLogicalType(child.first),
curChildValueType = pyLogicalType(child.second);
Expand All @@ -289,6 +287,9 @@ static std::unique_ptr<LogicalType> pyLogicalType(py::handle val) {

static Value transformPythonValueAs(py::handle val, const LogicalType* type) {
// ignore the type of the actual python object, just directly cast
if (val.is_none()) {
return Value::createNullValue(*type);
}
switch (type->getLogicalTypeID()) {
case LogicalTypeID::ANY:
return Value::createNullValue();
Expand Down
38 changes: 32 additions & 6 deletions tools/python_api/test/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,23 +97,20 @@ def test_string_list_param(conn_db_readonly: ConnDB) -> None:
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()
result = conn.execute(
"MATCH (a:person {usedNames: $1}) RETURN COUNT(*);",
{"1": []}) # empty list is string
result.close()

def test_map_param(tmp_path: Path) -> None:
db = kuzu.Database(tmp_path)
conn = kuzu.Connection(db)
conn.execute("CREATE NODE TABLE tab(id int64, mp MAP(double, int64), mp2 MAP(int64, double), mp3 MAP(string, string), mp4 MAP(string, string), primary key(id))")
conn.execute("CREATE NODE TABLE tab(id int64, mp MAP(double, int64), mp2 MAP(int64, double), mp3 MAP(string, string), mp4 MAP(string, string)[], primary key(id))")
result = conn.execute(
"MERGE (t:tab {id: 0, mp: $1, mp2: $2, mp3: $3, mp4: $4}) RETURN t.*",
{"1": {1.0: 5, 2: 3, 2.2: -1},
"2": {5: -0.5, 4: 0, 0: 2.2},
"3": {'a': 1, 'b': '2', 'c': '3'},
"4": {}})
"4": [{}, {'a': 'b'}]})
assert result.has_next()
assert result.get_next() == [0, {1.0: 5, 2.0: 3, 2.2: -1}, {5: -0.5, 4: -0.0, 0: 2.2}, {'a': '1', 'b': '2', 'c': '3'}, {}]
assert result.get_next() == [0, {1.0: 5, 2.0: 3, 2.2: -1}, {5: -0.5, 4: -0.0, 0: 2.2}, {'a': '1', 'b': '2', 'c': '3'}, [{}, {'a': 'b'}]]
assert not result.has_next()
result.close()

Expand All @@ -139,6 +136,35 @@ def test_general_list_param(tmp_path: Path) -> None:
assert not result.has_next()
result.close()

def test_null_resolution(tmp_path: Path) -> None:
db = kuzu.Database(tmp_path)
conn = kuzu.Connection(db)
conn.execute(
"CREATE NODE TABLE tab(id SERIAL, lst1 INT64[], mp1 MAP(STRING, STRING), nest MAP(STRING, MAP(STRING, INT64))[], PRIMARY KEY(id))")
lst1 = [1, 2, 3, None]
mp1 = {'a': 'x', 'b': 'y', 'c': 'z', 'o': None}
nest = [{'a': {'foo' : 1, 'bar' : 2}}, {1: {}}]
result = conn.execute(
"MERGE (t:tab {lst1: $1, mp1: $2, nest: $3}) RETURN t.*",
{'1': lst1, '2': mp1, '3': nest})
assert result.has_next()
assert result.get_next() == [0, lst1, mp1, [{'a': {'foo' : 1, 'bar' : 2}}, {'1': {}}]]
assert not result.has_next()
result.close()

# def test_param_empty(tmp_path: Path) -> None:
# db = kuzu.Database(tmp_path)
# conn = kuzu.Connection(db)
# lst = [[]]
# result = conn.execute("CREATE NODE TABLE tab(id SERIAL, lst INT64[][], PRIMARY KEY(id))")
# result = conn.execute(
# "MERGE (t:tab {lst: $1}) RETURN t.*",
# {'1': lst})
# assert result.has_next()
# assert result.get_next == [0, lst]
# assert not result.has_next()
# result.close()

def test_param_error1(conn_db_readonly: ConnDB) -> None:
conn, db = conn_db_readonly
with pytest.raises(RuntimeError, match="Parameter name must be of type string but got <class 'int'>"):
Expand Down

0 comments on commit 6c82aad

Please sign in to comment.