diff --git a/velox/expression/tests/ExpressionFuzzer.cpp b/velox/expression/tests/ExpressionFuzzer.cpp index 9ec67bd4b414..fdd249e658a0 100644 --- a/velox/expression/tests/ExpressionFuzzer.cpp +++ b/velox/expression/tests/ExpressionFuzzer.cpp @@ -44,26 +44,22 @@ class DecimalArgGeneratorBase : public ArgGenerator { const TypePtr& returnType, FuzzerGenerator& rng) override { auto inputs = findInputs(returnType, rng); - if (inputs.a == nullptr || inputs.b == nullptr) { - return {}; + for (const auto& input : inputs) { + if (input == nullptr) { + return {}; + } } - - return {std::move(inputs.a), std::move(inputs.b)}; + return std::move(inputs); } protected: // Compute result type for all possible pairs of decimal input types. Store // the results in 'inputs_' maps keyed by return type. void initialize() { - std::vector allTypes; - for (auto p = 1; p < 38; ++p) { - for (auto s = 0; s <= p; ++s) { - allTypes.push_back(DECIMAL(p, s)); - } - } - - for (auto& a : allTypes) { - for (auto& b : allTypes) { + // By default, the result type is considered to be calculated from two input + // decimal types. + for (auto& a : allTypes_) { + for (auto& b : allTypes_) { auto [p1, s1] = getDecimalPrecisionScale(*a); auto [p2, s2] = getDecimalPrecisionScale(*b); @@ -74,10 +70,7 @@ class DecimalArgGeneratorBase : public ArgGenerator { } } - struct Inputs { - TypePtr a; - TypePtr b; - }; + using Inputs = std::vector; // Return randomly selected pair of input types that produce the specified // result type. @@ -95,10 +88,22 @@ class DecimalArgGeneratorBase : public ArgGenerator { // Given precisions and scales of the inputs, return precision and scale of // the result. - virtual std::optional> - toReturnType(int p1, int s1, int p2, int s2) = 0; + virtual std::optional> toReturnType(...) = 0; std::map, std::vector> inputs_; + + static const std::vector allTypes_ = generateAllTypes(); + + private: + std::vector generateAllTypes() { + std::vector allTypes; + for (auto p = 1; p < 38; ++p) { + for (auto s = 0; s <= p; ++s) { + allTypes.push_back(DECIMAL(p, s)); + } + } + return allTypes; + } }; class PlusMinusArgGenerator : public DecimalArgGeneratorBase { @@ -108,8 +113,15 @@ class PlusMinusArgGenerator : public DecimalArgGeneratorBase { } protected: - std::optional> - toReturnType(int p1, int s1, int p2, int s2) override { + std::optional> toReturnType(...) override { + va_list args; + va_start(args, 4); + int p1 = va_arg(args, int); + int s1 = va_arg(args, int); + int p2 = va_arg(args, int); + int s2 = va_arg(args, int); + va_end(args); + auto s = std::max(s1, s2); auto p = std::min(38, std::max(p1 - s1, p2 - s2) + 1 + s); return {{p, s}}; @@ -123,8 +135,15 @@ class MultiplyArgGenerator : public DecimalArgGeneratorBase { } protected: - std::optional> - toReturnType(int p1, int s1, int p2, int s2) override { + std::optional> toReturnType(...) override { + va_list args; + va_start(args, 4); + int p1 = va_arg(args, int); + int s1 = va_arg(args, int); + int p2 = va_arg(args, int); + int s2 = va_arg(args, int); + va_end(args); + if (s1 + s2 > 38) { return std::nullopt; } @@ -142,8 +161,15 @@ class DivideArgGenerator : public DecimalArgGeneratorBase { } protected: - std::optional> - toReturnType(int p1, int s1, int p2, int s2) override { + std::optional> toReturnType(...) override { + va_list args; + va_start(args, 4); + int p1 = va_arg(args, int); + int s1 = va_arg(args, int); + int p2 = va_arg(args, int); + int s2 = va_arg(args, int); + va_end(args); + if (s1 + s2 > 38) { return std::nullopt; } @@ -785,7 +811,7 @@ ExpressionFuzzer::ExpressionFuzzer( argTypes = typeFuzzer.argumentTypes(); if (!isDeterministic(function.first, argTypes)) { LOG(WARNING) << "Skipping non-deterministic function: " - << function.first << signature->toString(); + << function.first << signature->toString(); continue; } }