Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Decimal DIV changes in cudf [skip ci] #7527

Merged
merged 5 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions java/src/main/java/ai/rapids/cudf/BinaryOperable.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public interface BinaryOperable {
* with scale=0 as scale is required. Dtype is discarded for binary operations for decimal
* types in cudf as a new DType is created for output type with the new scale.
*/
static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) {
static DType implicitConversion(BinaryOp op, BinaryOperable lhs, BinaryOperable rhs) {
DType a = lhs.getType();
DType b = rhs.getType();
if (a.equals(DType.FLOAT64) || b.equals(DType.FLOAT64)) {
Expand Down Expand Up @@ -86,13 +86,15 @@ static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) {
int scale = 0;
if (a.typeId == DType.DTypeEnum.DECIMAL32) {
if (b.typeId == DType.DTypeEnum.DECIMAL32) {
return DType.create(DType.DTypeEnum.DECIMAL32, scale);
return DType.create(DType.DTypeEnum.DECIMAL32,
ColumnView.getFixedPointOutpuScale(op, lhs.getType(), rhs.getType()));
} else {
throw new IllegalArgumentException("Both columns must be of the same fixed_point type");
}
} else if (a.typeId == DType.DTypeEnum.DECIMAL64) {
if (b.typeId == DType.DTypeEnum.DECIMAL64) {
return DType.create(DType.DTypeEnum.DECIMAL64, scale);
return DType.create(DType.DTypeEnum.DECIMAL64,
ColumnView.getFixedPointOutpuScale(op, lhs.getType(), rhs.getType()));
} else {
throw new IllegalArgumentException("Both columns must be of the same fixed_point type");
}
Expand Down Expand Up @@ -128,7 +130,7 @@ default ColumnVector add(BinaryOperable rhs, DType outType) {
* Add + operator. this + rhs
*/
default ColumnVector add(BinaryOperable rhs) {
return add(rhs, implicitConversion(this, rhs));
return add(rhs, implicitConversion(BinaryOp.ADD, this, rhs));
}

/**
Expand All @@ -144,7 +146,7 @@ default ColumnVector sub(BinaryOperable rhs, DType outType) {
* Subtract one vector from another. this - rhs
*/
default ColumnVector sub(BinaryOperable rhs) {
return sub(rhs, implicitConversion(this, rhs));
return sub(rhs, implicitConversion(BinaryOp.SUB, this, rhs));
}

/**
Expand All @@ -160,7 +162,7 @@ default ColumnVector mul(BinaryOperable rhs, DType outType) {
* Multiply two vectors together. this * rhs
*/
default ColumnVector mul(BinaryOperable rhs) {
return mul(rhs, implicitConversion(this, rhs));
return mul(rhs, implicitConversion(BinaryOp.MUL, this, rhs));
}

/**
Expand All @@ -176,7 +178,7 @@ default ColumnVector div(BinaryOperable rhs, DType outType) {
* Divide one vector by another. this / rhs
*/
default ColumnVector div(BinaryOperable rhs) {
return div(rhs, implicitConversion(this, rhs));
return div(rhs, implicitConversion(BinaryOp.DIV, this, rhs));
}

/**
Expand All @@ -192,7 +194,7 @@ default ColumnVector trueDiv(BinaryOperable rhs, DType outType) {
* (double)this / (double)rhs
*/
default ColumnVector trueDiv(BinaryOperable rhs) {
return trueDiv(rhs, implicitConversion(this, rhs));
return trueDiv(rhs, implicitConversion(BinaryOp.TRUE_DIV, this, rhs));
}

/**
Expand All @@ -208,7 +210,7 @@ default ColumnVector floorDiv(BinaryOperable rhs, DType outType) {
* Math.floor(this/rhs)
*/
default ColumnVector floorDiv(BinaryOperable rhs) {
return floorDiv(rhs, implicitConversion(this, rhs));
return floorDiv(rhs, implicitConversion(BinaryOp.FLOOR_DIV, this, rhs));
}

/**
Expand All @@ -224,7 +226,7 @@ default ColumnVector mod(BinaryOperable rhs, DType outType) {
* this % rhs
*/
default ColumnVector mod(BinaryOperable rhs) {
return mod(rhs, implicitConversion(this, rhs));
return mod(rhs, implicitConversion(BinaryOp.MOD, this, rhs));
}

/**
Expand All @@ -240,7 +242,7 @@ default ColumnVector pow(BinaryOperable rhs, DType outType) {
* Math.pow(this, rhs)
*/
default ColumnVector pow(BinaryOperable rhs) {
return pow(rhs, implicitConversion(this, rhs));
return pow(rhs, implicitConversion(BinaryOp.POW, this, rhs));
}

/**
Expand Down Expand Up @@ -338,7 +340,7 @@ default ColumnVector bitAnd(BinaryOperable rhs, DType outType) {
* Bit wise and (&). this & rhs
*/
default ColumnVector bitAnd(BinaryOperable rhs) {
return bitAnd(rhs, implicitConversion(this, rhs));
return bitAnd(rhs, implicitConversion(BinaryOp.BITWISE_AND, this, rhs));
}

/**
Expand All @@ -352,7 +354,7 @@ default ColumnVector bitOr(BinaryOperable rhs, DType outType) {
* Bit wise or (|). this | rhs
*/
default ColumnVector bitOr(BinaryOperable rhs) {
return bitOr(rhs, implicitConversion(this, rhs));
return bitOr(rhs, implicitConversion(BinaryOp.BITWISE_OR, this, rhs));
}

/**
Expand All @@ -366,7 +368,7 @@ default ColumnVector bitXor(BinaryOperable rhs, DType outType) {
* Bit wise xor (^). this ^ rhs
*/
default ColumnVector bitXor(BinaryOperable rhs) {
return bitXor(rhs, implicitConversion(this, rhs));
return bitXor(rhs, implicitConversion(BinaryOp.BITWISE_XOR, this, rhs));
}

/**
Expand All @@ -380,7 +382,7 @@ default ColumnVector and(BinaryOperable rhs, DType outType) {
* Logical and (&&). this && rhs
*/
default ColumnVector and(BinaryOperable rhs) {
return and(rhs, implicitConversion(this, rhs));
return and(rhs, implicitConversion(BinaryOp.LOGICAL_AND, this, rhs));
}

/**
Expand All @@ -394,7 +396,7 @@ default ColumnVector or(BinaryOperable rhs, DType outType) {
* Logical or (||). this || rhs
*/
default ColumnVector or(BinaryOperable rhs) {
return or(rhs, implicitConversion(this, rhs));
return or(rhs, implicitConversion(BinaryOp.LOGICAL_OR, this, rhs));
}

/**
Expand All @@ -421,7 +423,7 @@ default ColumnVector shiftLeft(BinaryOperable shiftBy, DType outType) {
* with this[i] << shiftBy
*/
default ColumnVector shiftLeft(BinaryOperable shiftBy) {
return shiftLeft(shiftBy, implicitConversion(this, shiftBy));
return shiftLeft(shiftBy, implicitConversion(BinaryOp.SHIFT_LEFT, this, shiftBy));
}

/**
Expand All @@ -447,7 +449,7 @@ default ColumnVector shiftRight(BinaryOperable shiftBy, DType outType) {
* with this[i] >> shiftBy
*/
default ColumnVector shiftRight(BinaryOperable shiftBy) {
return shiftRight(shiftBy, implicitConversion(this, shiftBy));
return shiftRight(shiftBy, implicitConversion(BinaryOp.SHIFT_RIGHT, this, shiftBy));
}

/**
Expand Down Expand Up @@ -475,7 +477,8 @@ default ColumnVector shiftRightUnsigned(BinaryOperable shiftBy, DType outType) {
* with this[i] >>> shiftBy
*/
default ColumnVector shiftRightUnsigned(BinaryOperable shiftBy) {
return shiftRightUnsigned(shiftBy, implicitConversion(this, shiftBy));
return shiftRightUnsigned(shiftBy, implicitConversion(BinaryOp.SHIFT_RIGHT_UNSIGNED, this,
shiftBy));
}

/**
Expand Down Expand Up @@ -505,7 +508,7 @@ default ColumnVector arctan2(BinaryOperable xCoordinate, DType outType) {
* in radians, between the positive x axis and the ray to the point (x, y) ≠ (0, 0).
*/
default ColumnVector arctan2(BinaryOperable xCoordinate) {
return arctan2(xCoordinate, implicitConversion(this, xCoordinate));
return arctan2(xCoordinate, implicitConversion(BinaryOp.ATAN2, this, xCoordinate));
}

/**
Expand All @@ -529,7 +532,7 @@ default ColumnVector pmod(BinaryOperable rhs, DType outputType) {
*
*/
default ColumnVector pmod(BinaryOperable rhs) {
return pmod(rhs, implicitConversion(this, rhs));
return pmod(rhs, implicitConversion(BinaryOp.PMOD, this, rhs));
}

/**
Expand Down Expand Up @@ -557,7 +560,7 @@ default ColumnVector maxNullAware(BinaryOperable rhs, DType outType) {
* Returns the max non null value.
*/
default ColumnVector maxNullAware(BinaryOperable rhs) {
return maxNullAware(rhs, implicitConversion(this, rhs));
return maxNullAware(rhs, implicitConversion(BinaryOp.NULL_MAX, this, rhs));
}

/**
Expand All @@ -571,6 +574,7 @@ default ColumnVector minNullAware(BinaryOperable rhs, DType outType) {
* Returns the min non null value.
*/
default ColumnVector minNullAware(BinaryOperable rhs) {
return minNullAware(rhs, implicitConversion(this, rhs));
return minNullAware(rhs, implicitConversion(BinaryOp.NULL_MIN, this, rhs));
}

}
7 changes: 7 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ public final long getNativeView() {
return viewHandle;
}

public static int getFixedPointOutpuScale(BinaryOp op, DType lhsType, DType rhsType) {
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
jlowe marked this conversation as resolved.
Show resolved Hide resolved
assert (lhsType.isDecimalType() && rhsType.isDecimalType());
return fixedPointOutputScale(op.nativeId, lhsType.getScale(), rhsType.getScale());
}

private static native int fixedPointOutputScale(int op, int lhsScale, int rhsScale);

public final DType getType() {
return type;
}
Expand Down
10 changes: 10 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/structs/structs_column_view.hpp>
#include <map_lookup.hpp>
#include "cudf/types.hpp"

#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
Expand Down Expand Up @@ -1026,6 +1027,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, j
CATCH_STD(env, 0);
}

JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnView_fixedPointOutputScale(JNIEnv *env, jclass, jint int_op,
jint lhs_scale, jint rhs_scale) {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
try {
// we just return the scale as the types will be the same as the lhs input
return cudf::binary_operation_fixed_point_scale(static_cast<cudf::binary_operator>(int_op), lhs_scale, rhs_scale);
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVS(JNIEnv *env, jclass,
jlong lhs_view, jlong rhs_ptr,
jint int_op, jint out_dtype,
Expand Down