Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type comparison for java [skip ci] #6970

Merged
merged 1 commit into from
Dec 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions java/src/main/java/ai/rapids/cudf/BinaryOperable.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,42 +41,42 @@ public interface BinaryOperable {
static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) {
DType a = lhs.getType();
DType b = rhs.getType();
if (a == DType.FLOAT64 || b == DType.FLOAT64) {
if (a.equals(DType.FLOAT64) || b.equals(DType.FLOAT64)) {
return DType.FLOAT64;
}
if (a == DType.FLOAT32 || b == DType.FLOAT32) {
if (a.equals(DType.FLOAT32) || b.equals(DType.FLOAT32)) {
return DType.FLOAT32;
}
if (a == DType.UINT64 || b == DType.UINT64) {
if (a.equals(DType.UINT64) || b.equals(DType.UINT64)) {
return DType.UINT64;
}
if (a == DType.INT64 || b == DType.INT64 ||
a == DType.TIMESTAMP_MILLISECONDS || b == DType.TIMESTAMP_MILLISECONDS ||
a == DType.TIMESTAMP_MICROSECONDS || b == DType.TIMESTAMP_MICROSECONDS ||
a == DType.TIMESTAMP_SECONDS || b == DType.TIMESTAMP_SECONDS ||
a == DType.TIMESTAMP_NANOSECONDS || b == DType.TIMESTAMP_NANOSECONDS) {
if (a.equals(DType.INT64) || b.equals(DType.INT64) ||
a.equals(DType.TIMESTAMP_MILLISECONDS) || b.equals(DType.TIMESTAMP_MILLISECONDS) ||
a.equals(DType.TIMESTAMP_MICROSECONDS) || b.equals(DType.TIMESTAMP_MICROSECONDS) ||
a.equals(DType.TIMESTAMP_SECONDS) || b.equals(DType.TIMESTAMP_SECONDS) ||
a.equals(DType.TIMESTAMP_NANOSECONDS) || b.equals(DType.TIMESTAMP_NANOSECONDS)) {
return DType.INT64;
}
if (a == DType.UINT32 || b == DType.UINT32) {
if (a.equals(DType.UINT32) || b.equals(DType.UINT32)) {
return DType.UINT32;
}
if (a == DType.INT32 || b == DType.INT32 ||
a == DType.TIMESTAMP_DAYS || b == DType.TIMESTAMP_DAYS) {
if (a.equals(DType.INT32) || b.equals(DType.INT32) ||
a.equals(DType.TIMESTAMP_DAYS) || b.equals(DType.TIMESTAMP_DAYS)) {
return DType.INT32;
}
if (a == DType.UINT16 || b == DType.UINT16) {
if (a.equals(DType.UINT16) || b.equals(DType.UINT16)) {
return DType.UINT16;
}
if (a == DType.INT16 || b == DType.INT16) {
if (a.equals(DType.INT16) || b.equals(DType.INT16)) {
return DType.INT16;
}
if (a == DType.UINT8 || b == DType.UINT8) {
if (a.equals(DType.UINT8) || b.equals(DType.UINT8)) {
return DType.UINT8;
}
if (a == DType.INT8 || b == DType.INT8) {
if (a.equals(DType.INT8) || b.equals(DType.INT8)) {
return DType.INT8;
}
if (a == DType.BOOL8 || b == DType.BOOL8) {
if (a.equals(DType.BOOL8) || b.equals(DType.BOOL8)) {
return DType.BOOL8;
}
if (a.isDecimalType() && b.isDecimalType()) {
Expand Down
20 changes: 10 additions & 10 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ public ColumnVector(DType type, long rows, Optional<Long> nullCount,
super(ColumnVector.initViewHandle(
type, (int)rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(),
dataBuffer, validityBuffer, offsetBuffer, null));
assert type != DType.LIST : "This constructor should not be used for list type";
if (type != DType.STRING) {
assert !type.equals(DType.LIST) : "This constructor should not be used for list type";
if (!type.equals(DType.STRING)) {
assert offsetBuffer == null : "offsets are only supported for STRING";
}
assert (nullCount.isPresent() && nullCount.get() <= Integer.MAX_VALUE)
Expand Down Expand Up @@ -120,7 +120,7 @@ public ColumnVector(DType type, long rows, Optional<Long> nullCount,
super(initViewHandle(type, (int)rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(),
dataBuffer, validityBuffer,
offsetBuffer, childHandles));
if (type != DType.STRING && type != DType.LIST) {
if (!type.equals(DType.STRING) && !type.equals(DType.LIST)) {
assert offsetBuffer == null : "offsets are only supported for STRING, LISTS";
}
assert (nullCount.isPresent() && nullCount.get() <= Integer.MAX_VALUE)
Expand Down Expand Up @@ -393,15 +393,15 @@ public static ColumnVector stringConcatenate(ColumnView[] columns) {
public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, ColumnView[] columns) {
assert columns.length >= 2 : ".stringConcatenate() operation requires at least 2 columns";
assert separator != null : "separator scalar provided may not be null";
assert separator.getType() == DType.STRING : "separator scalar must be a string scalar";
assert separator.getType().equals(DType.STRING) : "separator scalar must be a string scalar";
assert narep != null : "narep scalar provided may not be null";
assert narep.getType() == DType.STRING : "narep scalar must be a string scalar";
assert narep.getType().equals(DType.STRING) : "narep scalar must be a string scalar";
long size = columns[0].getRowCount();
long[] column_views = new long[columns.length];

for(int i = 0; i < columns.length; i++) {
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getType() == DType.STRING : "All columns must be of type string for .cat() operation";
assert columns[i].getType().equals(DType.STRING) : "All columns must be of type string for .cat() operation";
assert columns[i].getRowCount() == size : "Row count mismatch, all columns must have the same number of rows";
column_views[i] = columns[i].getNativeView();
}
Expand All @@ -426,8 +426,8 @@ public static ColumnVector md5Hash(ColumnView... columns) {
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
assert !columns[i].getType().isTimestamp() : "Unsupported column type Timestamp";
assert !columns[i].getType().isNestedType() || columns[i].getType() == DType.LIST :
assert !columns[i].getType().isTimestampType() : "Unsupported column type Timestamp";
assert !columns[i].getType().isNestedType() || columns[i].getType().equals(DType.LIST) :
"Unsupported nested type column";
columnViews[i] = columns[i].getNativeView();
}
Expand All @@ -452,7 +452,7 @@ public static ColumnVector serial32BitMurmurHash3(int seed, ColumnView columns[]
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
assert !columns[i].getType().isTimestamp() : "Unsupported column type Timestamp";
assert !columns[i].getType().isTimestampType() : "Unsupported column type Timestamp";
assert !columns[i].getType().isNestedType() : "Unsupported column of nested type";
columnViews[i] = columns[i].getNativeView();
}
Expand Down Expand Up @@ -492,7 +492,7 @@ public static ColumnVector serial32BitMurmurHash3(ColumnView columns[]) {
*/
@Override
public ColumnVector castTo(DType type) {
if (this.type == type) {
if (this.type.equals(type)) {
// Optimization
return incRefCount();
}
Expand Down
Loading