Skip to content

Commit

Permalink
Support varying types in LHS and RHS of IN operator (#3828)
Browse files Browse the repository at this point in the history
  • Loading branch information
marsaly79 authored May 30, 2024
1 parent 08261c7 commit 3edc507
Show file tree
Hide file tree
Showing 78 changed files with 4,074 additions and 6,570 deletions.
17 changes: 0 additions & 17 deletions ydb/library/yql/core/common_opt/yql_co_pgselect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4093,23 +4093,6 @@ TExprNode::TPtr ExpandPgLike(const TExprNode::TPtr& node, TExprContext& ctx, TOp
.Build();
}

TExprNode::TPtr ExpandPgIn(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx) {
Y_UNUSED(optCtx);
return ctx.Builder(node->Pos())
.Callable("ToPg")
.Callable(0, "SqlIn")
.Add(0, node->ChildPtr(1))
.Add(1, node->ChildPtr(0))
.List(2)
.List(0)
.Atom(0, "ansi")
.Seal()
.Seal()
.Seal()
.Seal()
.Build();
}

TExprNode::TPtr ExpandPgBetween(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx) {
Y_UNUSED(optCtx);
const bool isSym = node->IsCallable("PgBetweenSym");
Expand Down
2 changes: 0 additions & 2 deletions ydb/library/yql/core/common_opt/yql_co_pgselect.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ TExprNode::TPtr NormalizeColumnOrder(const TExprNode::TPtr& node, const TColumnO

TExprNode::TPtr ExpandPgLike(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx);

TExprNode::TPtr ExpandPgIn(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx);

TExprNode::TPtr ExpandPgBetween(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx);

TExprNode::TPtr ExpandPgGroupRef(const TExprNode::TPtr& node, TExprContext& ctx, TOptimizeContext& optCtx);
Expand Down
2 changes: 0 additions & 2 deletions ydb/library/yql/core/common_opt/yql_co_simple1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6263,8 +6263,6 @@ void RegisterCoSimpleCallables1(TCallableOptimizerMap& map) {
map["PgLike"] = &ExpandPgLike;
map["PgILike"] = &ExpandPgLike;

map["PgIn"] = &ExpandPgIn;

map["PgBetween"] = &ExpandPgBetween;
map["PgBetweenSym"] = &ExpandPgBetween;

Expand Down
208 changes: 123 additions & 85 deletions ydb/library/yql/core/type_ann/type_ann_pg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ namespace NYql {

namespace NTypeAnnImpl {

const NPg::TTypeDesc& GetTypeDescOfNode(const TExprNodePtr& node)
{
const auto typeId = node->GetTypeAnn()->Cast<TPgExprType>()->GetId();

return NPg::LookupType(typeId);
}

bool IsCastRequired(ui32 fromTypeId, ui32 toTypeId) {
if (toTypeId == fromTypeId) {
return false;
Expand Down Expand Up @@ -76,6 +83,17 @@ TExprNodePtr WrapWithPgCast(TExprNodePtr&& node, ui32 targetTypeId, TExprContext
.Build();
};

TExprNodePtr WrapWithPgCast(TExprNodePtr& node, ui32 targetTypeId, TExprContext& ctx) {
return ctx.Builder(node->Pos())
.Callable("PgCast")
.Add(0, node)
.Callable(1, "PgType")
.Atom(0, NPg::LookupType(targetTypeId).Name)
.Seal()
.Seal()
.Build();
};

TExprNodePtr FindLeftCombinatorOfNthSetItem(const TExprNode* setItems, const TExprNode* setOps, ui32 n) {
TVector<ui32> setItemsStack(setItems->ChildrenSize());
i32 sp = -1;
Expand Down Expand Up @@ -5430,115 +5448,135 @@ IGraphTransformer::TStatus PgLikeWrapper(const TExprNode::TPtr& input, TExprNode
return IGraphTransformer::TStatus::Ok;
}

IGraphTransformer::TStatus PgInWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
if (!EnsureArgsCount(*input, 2, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
TExprNodePtr BuildUniTypePgIn(TExprNodeList&& args, TContext& ctx) {
auto lhs = args[0];
std::swap(args[0], args.back());
args.pop_back();

auto inputType = input->Child(0)->GetTypeAnn();
ui32 inputTypePg;
bool convertToPg;
if (!ExtractPgType(inputType, inputTypePg, convertToPg, input->Child(0)->Pos(), ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

if (convertToPg) {
input->ChildRef(0) = ctx.Expr.NewCallable(input->Child(0)->Pos(), "ToPg", { input->ChildPtr(0) });
return IGraphTransformer::TStatus::Repeat;
}
return ctx.Expr.Builder(lhs->Pos())
.Callable("SqlIn")
.List(0)
.Add(std::move(args))
.Seal()
.Add(1, lhs)
.List(2)
.List(0)
.Atom(0, "ansi")
.Seal()
.Seal()
.Seal()
.Build();
}

auto listType = input->Child(1)->GetTypeAnn();
if (listType && listType->GetKind() == ETypeAnnotationKind::EmptyList) {
IGraphTransformer::TStatus PgInWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
if (!EnsureMinArgsCount(*input, 2, ctx.Expr)) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()), "IN expects at least one element"));
return IGraphTransformer::TStatus::Error;
}

if (!EnsureListType(*input->Child(1), ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

auto listItemType = listType->Cast<TListExprType>()->GetItemType();
ui32 itemTypePg;
if (!ExtractPgType(listItemType, itemTypePg, convertToPg, input->Child(1)->Pos(), ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
TVector<ui32> pgTypes(input->ChildrenSize());
{
TExprNodeList convertedChildren;
convertedChildren.reserve(input->ChildrenSize());
bool convertionRequired = false;
bool hasConvertions = false;
for (size_t i = 0; i < input->ChildrenSize(); ++i) {
const auto child = input->Child(i);
if (!ExtractPgType(child->GetTypeAnn(), pgTypes[i], convertionRequired, child->Pos(), ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
if (convertionRequired) {
hasConvertions = true;

if (convertToPg) {
output = ctx.Expr.Builder(input->Pos())
.Callable("PgIn")
.Add(0, input->ChildPtr(0))
.Callable(1, "Map")
.Add(0, input->ChildPtr(1))
.Lambda(1)
.Param("x")
.Callable("ToPg")
.Arg(0, "x")
.Seal()
auto convertedChild = ctx.Expr.Builder(child->Pos())
.Callable("ToPg")
.Add(0, child)
.Seal()
.Seal()
.Seal()
.Build();

return IGraphTransformer::TStatus::Repeat;
}

if (itemTypePg && inputTypePg && itemTypePg != inputTypePg) {
if (inputTypePg == NPg::UnknownOid) {

input->ChildRef(0) = WrapWithPgCast(std::move(input->Child(0)), itemTypePg, ctx.Expr);
return IGraphTransformer::TStatus::Repeat;
.Build();
convertedChildren.push_back(std::move(convertedChild));
} else {
convertedChildren.push_back(std::move(child));
}
}
if (itemTypePg == NPg::UnknownOid) {
if (hasConvertions) {
output = ctx.Expr.Builder(input->Pos())
.Callable("PgIn")
.Add(0, input->ChildPtr(0))
.Callable(1, "Map")
.Add(0, input->ChildPtr(1))
.Lambda(1)
.Param("x")
.Callable("PgCast")
.Arg(0, "x")
.Callable(1, "PgType")
.Atom(0, NPg::LookupType(inputTypePg).Name)
.Seal()
.Seal()
.Seal()
.Seal()
.Add(std::move(convertedChildren))
.Seal()
.Build();

return IGraphTransformer::TStatus::Repeat;
}

ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
TStringBuilder() << "Mismatch of types in IN expressions: " <<
NPg::LookupType(inputTypePg).Name << " is not equal to " << NPg::LookupType(itemTypePg).Name));
return IGraphTransformer::TStatus::Error;
}

if (itemTypePg && !listItemType->IsEquatable()) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
TStringBuilder() << "Cannot compare items of type: " << NPg::LookupType(itemTypePg).Name));
}
auto posGetter = [&input, &ctx](size_t i) {
return ctx.Expr.GetPosition(input->Child(i)->Pos());
};

if (inputTypePg && !inputType->IsEquatable()) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
TStringBuilder() << "Cannot compare items of type: " << NPg::LookupType(inputTypePg).Name));
}
struct TPgListCommonTypeConversion {
ui32 targetType = 0;
TExprNodeList items;
};

bool castRequired = false;
const NPg::TTypeDesc* commonType;
if (NPg::LookupCommonType(pgTypes, posGetter, commonType, castRequired))
{
const auto& lhsTypeId = input->Head().GetTypeAnn()->Cast<TPgExprType>()->GetId();

THashMap<ui32, TPgListCommonTypeConversion> elemsByType;
for (size_t i = 1; i < input->ChildrenSize(); ++i) {
auto& elemsOfType = elemsByType[pgTypes[i]];
if (elemsOfType.items.empty()) {
const NPg::TTypeDesc* elemCommonType;
if (const auto issue = NPg::LookupCommonType({pgTypes[0], pgTypes[i]},
posGetter, elemCommonType))
{
ctx.Expr.AddError(*issue);
return IGraphTransformer::TStatus::Error;
}
elemsOfType.targetType = elemCommonType->TypeId;

elemsOfType.items.push_back((lhsTypeId == elemsOfType.targetType)
? input->HeadPtr()
: WrapWithPgCast(input->HeadPtr(), elemsOfType.targetType, ctx.Expr));
}
const auto rhsItemTypeId = input->Child(i)->GetTypeAnn()->Cast<TPgExprType>()->GetId();
elemsOfType.items.push_back((rhsItemTypeId == elemsOfType.targetType)
? input->Child(i)
: WrapWithPgCast(input->Child(i), elemsOfType.targetType, ctx.Expr));
}
TExprNodeList orClausesOfIn;
orClausesOfIn.reserve(elemsByType.size());

if (!itemTypePg || !inputTypePg) {
for (auto& elemsOfType: elemsByType) {
auto& conversion = elemsOfType.second;
orClausesOfIn.push_back(BuildUniTypePgIn(std::move(conversion.items), ctx));
}
output = ctx.Expr.Builder(input->Pos())
.Callable("Nothing")
.Callable(0, "PgType")
.Atom(0, "bool")
.Seal()
.Callable("Or")
.Add(std::move(orClausesOfIn))
.Seal()
.Build();
} else {
TExprNodeList items;

if (castRequired) {
for (size_t i = 0; i < input->ChildrenSize(); ++i) {
const auto itemTypeId = input->Child(i)->GetTypeAnn()->Cast<TPgExprType>()->GetId();
items.push_back((itemTypeId == commonType->TypeId)
? input->Child(i)
: WrapWithPgCast(input->Child(i), commonType->TypeId, ctx.Expr));
}
}
output = BuildUniTypePgIn(std::move((castRequired) ? items : input->ChildrenList()), ctx);
}

auto result = ctx.Expr.MakeType<TPgExprType>(NPg::LookupType("bool").TypeId);
input->SetTypeAnn(result);
return IGraphTransformer::TStatus::Ok;
output = ctx.Expr.Builder(input->Pos())
.Callable("ToPg")
.Add(0, output)
.Seal()
.Build();
return IGraphTransformer::TStatus::Repeat;
}

IGraphTransformer::TStatus PgBetweenWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
Expand Down
13 changes: 8 additions & 5 deletions ydb/library/yql/sql/pg/pg_sql.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4769,18 +4769,21 @@ class TConverter : public IPGParseEvents {
}

auto lst = CAST_NODE(List, value->rexpr);
TVector<TAstNode*> listItems;
listItems.push_back(A("AsList"));

TVector<TAstNode*> children;
children.reserve(2 + ListLength(lst));

children.push_back(A("PgIn"));
children.push_back(lhs);
for (int item = 0; item < ListLength(lst); ++item) {
auto cell = ParseExpr(ListNodeNth(lst, item), settings);
if (!cell) {
return nullptr;
}

listItems.push_back(cell);
children.push_back(cell);
}

auto ret = L(A("PgIn"), lhs, VL(listItems.data(), listItems.size()));
auto ret = VL(children.data(), children.size());
if (op[0] == '<') {
ret = L(A("PgNot"), ret);
}
Expand Down
Loading

0 comments on commit 3edc507

Please sign in to comment.