Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Oct 14, 2024
1 parent 27e56f0 commit 945505a
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 6 deletions.
120 changes: 115 additions & 5 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/expression_internal.h"
#include "arrow/compute/function_internal.h"
#include "arrow/compute/special_form.h"
#include "arrow/compute/util.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
Expand Down Expand Up @@ -55,6 +56,13 @@ void Expression::Call::ComputeHash() {
}
}

void Expression::Special::ComputeHash() {
hash = std::hash<std::string>{}(special_form->name);
for (const auto& arg : arguments) {
arrow::internal::hash_combine(hash, arg.hash());
}
}

Expression::Expression(Call call) {
call.ComputeHash();
impl_ = std::make_shared<Impl>(std::move(call));
Expand All @@ -66,6 +74,9 @@ Expression::Expression(Datum literal)
// Construct an Expression whose Impl holds the given Parameter.
Expression::Expression(Parameter parameter) {
  impl_ = std::make_shared<Impl>(std::move(parameter));
}

// Construct an Expression whose Impl holds the given Special.
//
// As in the Call constructor, the cached hash is computed before the value is
// stored; without this, Expression::hash() would return an uninitialized
// Special::hash.
//
// Precondition: special.special_form is non-null, because ComputeHash()
// dereferences it.
Expression::Expression(Special special) {
  special.ComputeHash();
  impl_ = std::make_shared<Impl>(std::move(special));
}

// Build an Expression that holds the given Datum as a literal.
Expression literal(Datum lit) {
  Expression wrapped(std::move(lit));
  return wrapped;
}

Expression field_ref(FieldRef ref) {
Expand All @@ -81,6 +92,13 @@ Expression call(std::string function, std::vector<Expression> arguments,
return Expression(std::move(call));
}

// Build an (unbound) expression that applies the IfElseSpecialForm to three
// operands, stored in the order (cond, if_true, if_false).
//
// \param cond the condition operand
// \param if_true the second operand
// \param if_false the third operand
// \return an Expression holding an Expression::Special
Expression if_else_special(Expression cond, Expression if_true, Expression if_false) {
  Expression::Special special;
  // Elements of a braced initializer list are const, so `= {std::move(a), ...}`
  // would copy each operand; emplace_back genuinely moves them.
  special.arguments.reserve(3);
  special.arguments.emplace_back(std::move(cond));
  special.arguments.emplace_back(std::move(if_true));
  special.arguments.emplace_back(std::move(if_false));
  special.special_form = std::make_shared<IfElseSpecialForm>();
  return Expression(std::move(special));
}

const Datum* Expression::literal() const {
if (impl_ == nullptr) return nullptr;

Expand All @@ -106,6 +124,12 @@ const Expression::Call* Expression::call() const {
return std::get_if<Call>(impl_.get());
}

// Return a pointer to the held Special, or nullptr when this expression is
// empty or holds a different alternative.
const Expression::Special* Expression::special() const {
  return impl_ ? std::get_if<Special>(impl_.get()) : nullptr;
}

const DataType* Expression::type() const {
if (impl_ == nullptr) return nullptr;

Expand All @@ -117,7 +141,11 @@ const DataType* Expression::type() const {
return parameter->type.type;
}

return CallNotNull(*this)->type.type;
if (const Call* call = this->call()) {
return call->type.type;
}

return SpecialNotNull(*this)->type.type;
}

namespace {
Expand Down Expand Up @@ -268,7 +296,11 @@ size_t Expression::hash() const {
return ref->hash();
}

return CallNotNull(*this)->hash;
if (auto call = call()) {
return call->hash;
}

return SpecialNotNull(*this)->hash;
}

bool Expression::IsBound() const {
Expand Down Expand Up @@ -601,6 +633,78 @@ Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_
return Expression(std::move(call));
}

// Bind a special form whose arguments have already been bound.
//
// Expression::Special has no function/kernel/options/kernel_state members, so the
// kernel dispatch logic of the Call overload cannot be applied to it directly (the
// previous body referred to a `call` object that does not exist in this overload
// and therefore could not compile). Instead, the argument and output types are
// resolved by binding a temporary Call named after the special form; the resulting
// (possibly implicitly cast) arguments and output type are then copied back into
// the Special.
//
// NOTE(review): this assumes a compute function is registered under
// special_form->name (e.g. "if_else") and that its type resolution matches the
// special form's -- confirm before relying on it for other special forms.
//
// \param special the special form to bind; every argument must already be bound
// \param insert_implicit_casts forwarded to the Call overload of BindNonRecursive
// \param exec_context forwarded to the Call overload of BindNonRecursive
// \return the bound special-form Expression, or the error produced while binding
Result<Expression> BindNonRecursive(Expression::Special special,
                                    bool insert_implicit_casts,
                                    compute::ExecContext* exec_context) {
  DCHECK(std::all_of(special.arguments.begin(), special.arguments.end(),
                     [](const Expression& argument) { return argument.IsBound(); }));

  if (special.special_form == nullptr) {
    return Status::Invalid("Cannot bind a special form expression without a special form");
  }

  // Temporary Call used only to reuse function lookup, kernel dispatch and
  // implicit-cast insertion.
  Expression::Call delegate;
  delegate.function_name = special.special_form->name;
  delegate.arguments = std::move(special.arguments);

  ARROW_ASSIGN_OR_RAISE(
      Expression bound,
      BindNonRecursive(std::move(delegate), insert_implicit_casts, exec_context));

  // The Call overload always returns an Expression constructed from a Call.
  const Expression::Call* bound_call = bound.call();
  DCHECK_NE(bound_call, nullptr);

  special.arguments = bound_call->arguments;
  special.type = bound_call->type;
  // The arguments may have been replaced by casts, so refresh the cached hash.
  special.ComputeHash();
  return Expression(std::move(special));
}

template <typename TypeOrSchema>
Result<Expression> BindImpl(Expression expr, const TypeOrSchema& in,
compute::ExecContext* exec_context) {
Expand All @@ -622,12 +726,18 @@ Result<Expression> BindImpl(Expression expr, const TypeOrSchema& in,
return Expression{std::move(param)};
}

auto call = *CallNotNull(expr);
if (const Call* call = expr.call()) {
for (auto& argument : call.arguments) {
ARROW_ASSIGN_OR_RAISE(argument, BindImpl(std::move(argument), in, exec_context));
}
return BindNonRecursive(std::move(call),
/*insert_implicit_casts=*/true, exec_context);
}

auto special = *SpecialNotNull(expr);
for (auto& argument : call.arguments) {
ARROW_ASSIGN_OR_RAISE(argument, BindImpl(std::move(argument), in, exec_context));
}
return BindNonRecursive(std::move(call),
/*insert_implicit_casts=*/true, exec_context);
}

} // namespace
Expand Down
21 changes: 20 additions & 1 deletion cpp/src/arrow/compute/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
namespace arrow {
namespace compute {

class SpecialForm;

/// \defgroup expression-core Expressions to describe data transformations
///
/// @{
Expand Down Expand Up @@ -60,6 +62,18 @@ class ARROW_EXPORT Expression {
void ComputeHash();
};

/// A special form applied to a list of argument expressions
struct Special {
  /// The operand expressions, in the order expected by the special form
  std::vector<Expression> arguments;
  /// The special form applied to the arguments
  std::shared_ptr<SpecialForm> special_form;
  // Cached hash value, filled in by ComputeHash(). Zero-initialized so that
  // reading it before ComputeHash() has run is not an uninitialized read.
  size_t hash = 0;

  // post-Bind properties:
  TypeHolder type;

  /// Recompute `hash` from the special form's name and the arguments' hashes
  void ComputeHash();
};

std::string ToString() const;
bool Equals(const Expression& other) const;
size_t hash() const;
Expand Down Expand Up @@ -112,6 +126,8 @@ class ARROW_EXPORT Expression {
const Datum* literal() const;
/// Access a FieldRef or return nullptr if this expression is not a field_ref
const FieldRef* field_ref() const;
/// Access a Special or return nullptr if this expression is not a special form
const Special* special() const;

/// The type to which this expression will evaluate
const DataType* type() const;
Expand All @@ -131,9 +147,10 @@ class ARROW_EXPORT Expression {
explicit Expression(Call call);
explicit Expression(Datum literal);
explicit Expression(Parameter parameter);
explicit Expression(Special special);

private:
using Impl = std::variant<Datum, Parameter, Call>;
using Impl = std::variant<Datum, Parameter, Call, Special>;
std::shared_ptr<Impl> impl_;

ARROW_FRIEND_EXPORT friend bool Identical(const Expression& l, const Expression& r);
Expand Down Expand Up @@ -169,6 +186,8 @@ Expression call(std::string function, std::vector<Expression> arguments,
std::make_shared<Options>(std::move(options)));
}

Expression if_else_special(Expression cond, Expression if_true, Expression if_false);

/// Assemble a list of all fields referenced by an Expression at any depth.
ARROW_EXPORT
std::vector<FieldRef> FieldsInExpression(const Expression&);
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/compute/expression_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ inline const Expression::Call* CallNotNull(const Expression& expr) {
return call;
}

// Return expr.special(), DCHECK-ing that the expression actually holds a Special.
inline const Expression::Special* SpecialNotNull(const Expression& expr) {
  const Expression::Special* result = expr.special();
  DCHECK_NE(result, nullptr);
  return result;
}

inline std::vector<TypeHolder> GetTypes(const std::vector<Expression>& exprs) {
std::vector<TypeHolder> types(exprs.size());
for (size_t i = 0; i < exprs.size(); ++i) {
Expand Down
18 changes: 18 additions & 0 deletions cpp/src/arrow/compute/special_form.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "arrow/compute/expression.h"

namespace arrow::compute {

/// Base class for special forms; each special form is identified by its name.
class SpecialForm {
 public:
  /// \param form_name identifier of the special form, stored in `name`
  ///
  /// The previous constructor ignored its argument, leaving `name` empty for
  /// every special form.
  explicit SpecialForm(const std::string& form_name) : name(form_name) {}

  // Instances are derived from and held through pointers to this base class.
  virtual ~SpecialForm() = default;

  /// Name identifying this special form
  const std::string name;
};

/// The special form constructed with the name "if_else".
class IfElseSpecialForm : public SpecialForm {
 public:
  IfElseSpecialForm() : SpecialForm{"if_else"} {}
};

} // namespace arrow::compute

0 comments on commit 945505a

Please sign in to comment.