Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add overflow detection to IntervalYearMonthOperators #24617

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.AbstractIntType;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.BlockIndex;
import com.facebook.presto.spi.function.BlockPosition;
import com.facebook.presto.spi.function.IsNull;
Expand All @@ -42,6 +43,7 @@
import static com.facebook.presto.common.function.OperatorType.NEGATION;
import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL;
import static com.facebook.presto.common.function.OperatorType.SUBTRACT;
import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE;
import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH;
import static io.airlift.slice.Slices.utf8Slice;
import static java.lang.Math.toIntExact;
Expand All @@ -56,49 +58,81 @@ private IntervalYearMonthOperators()
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long add(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right)
{
return left + right;
try {
return Math.addExact((int) left, (int) right);
}
catch (ArithmeticException e) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow adding interval year-month values: " + left + " + " + right);
}
}

@ScalarOperator(SUBTRACT)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long subtract(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right)
{
return left - right;
try {
return Math.subtractExact((int) left, (int) right);
}
catch (ArithmeticException e) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow subtracting interval year-month values: " + left + " - " + right);
}
}

@ScalarOperator(MULTIPLY)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long multiplyByBigint(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.BIGINT) long right)
public static long multiplyByInteger(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.INTEGER) long right)
{
return left * right;
try {
return Math.multiplyExact((int) left, (int) right);
}
catch (ArithmeticException e) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying interval year-month value by integer: " + left + " * " + right);
}
}

@ScalarOperator(MULTIPLY)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long multiplyByDouble(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.DOUBLE) double right)
{
return (long) (left * right);
long result = (long) (left * right);
if (result < Integer.MIN_VALUE || result > Integer.MAX_VALUE) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying interval year-month value by double: " + left + " * " + right);
}
return result;
}

@ScalarOperator(MULTIPLY)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long bigintMultiply(@SqlType(StandardTypes.BIGINT) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right)
public static long integerMultiply(@SqlType(StandardTypes.INTEGER) long left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right)
{
return left * right;
try {
return Math.multiplyExact((int) left, (int) right);
}
catch (ArithmeticException e) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying integer by interval year-month value: " + left + " * " + right);
}
}

@ScalarOperator(MULTIPLY)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long doubleMultiply(@SqlType(StandardTypes.DOUBLE) double left, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long right)
{
return (long) (left * right);
long result = (long) (left * right);
if (result < Integer.MIN_VALUE || result > Integer.MAX_VALUE) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow multiplying double by interval year-month value: " + left + " * " + right);
}
return result;
}

@ScalarOperator(DIVIDE)
@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static long divideByDouble(@SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long left, @SqlType(StandardTypes.DOUBLE) double right)
{
return (long) (left / right);
long result = (long) (left / right);
if (result < Integer.MIN_VALUE || result > Integer.MAX_VALUE) {
throw new PrestoException(NUMERIC_VALUE_OUT_OF_RANGE, "Overflow dividing interval year-month value by double: " + left + " / " + right);
}
return result;
}

@ScalarOperator(NEGATION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class TestIntervalYearMonth
extends AbstractTestFunctions
{
private static final int MAX_SHORT = Short.MAX_VALUE;
private static final long MAX_INT_PLUS_1 = Integer.MAX_VALUE + 1L;

@Test
public void testObject()
Expand Down Expand Up @@ -74,6 +75,7 @@ public void testInvalidLiteral()
assertInvalidFunction("INTERVAL '124-X' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: 124-X");
assertInvalidFunction("INTERVAL '124--30' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: 124--30");
assertInvalidFunction("INTERVAL '--124--30' YEAR TO MONTH", "Invalid INTERVAL YEAR TO MONTH value: --124--30");
assertInvalidFunction(format("INTERVAL '%s' MONTH", MAX_INT_PLUS_1), "Invalid INTERVAL MONTH value: " + MAX_INT_PLUS_1);
}

@Test
Expand All @@ -82,6 +84,7 @@ public void testAdd()
assertFunction("INTERVAL '3' MONTH + INTERVAL '3' MONTH", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(6));
assertFunction("INTERVAL '6' YEAR + INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(12 * 12));
assertFunction("INTERVAL '3' MONTH + INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((6 * 12) + (3)));
assertNumericOverflow(format("INTERVAL '%s' MONTH + INTERVAL '1' MONTH", Integer.MAX_VALUE), format("Overflow adding interval year-month values: %s + 1", Integer.MAX_VALUE));
}

@Test
Expand All @@ -90,6 +93,7 @@ public void testSubtract()
assertFunction("INTERVAL '6' MONTH - INTERVAL '3' MONTH", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(3));
assertFunction("INTERVAL '9' YEAR - INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(3 * 12));
assertFunction("INTERVAL '3' MONTH - INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((3) - (6 * 12)));
assertNumericOverflow(format("-INTERVAL '%s' MONTH - INTERVAL '2' MONTH", Integer.MAX_VALUE), format("Overflow subtracting interval year-month values: -%s - 2", Integer.MAX_VALUE));
}

@Test
Expand All @@ -104,6 +108,13 @@ public void testMultiply()
assertFunction("2 * INTERVAL '6' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(12 * 12));
assertFunction("INTERVAL '1' YEAR * 2.5", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((int) (2.5 * 12)));
assertFunction("2.5 * INTERVAL '1' YEAR", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth((int) (2.5 * 12)));

assertNumericOverflow(format("INTERVAL '%s' MONTH * 2", Integer.MAX_VALUE), format("Overflow multiplying interval year-month value by integer: %s * 2", Integer.MAX_VALUE));
assertNumericOverflow(format("2 * INTERVAL '%s' MONTH", Integer.MAX_VALUE), format("Overflow multiplying integer by interval year-month value: 2 * %s", Integer.MAX_VALUE));
assertNumericOverflow(format("INTERVAL '%s' MONTH * 2.0", Integer.MAX_VALUE), format("Overflow multiplying interval year-month value by double: %s * 2.0", Integer.MAX_VALUE));
assertNumericOverflow(format("DOUBLE '2' * INTERVAL '%s' MONTH", Integer.MAX_VALUE), format("Overflow multiplying double by interval year-month value: 2.0 * %s", Integer.MAX_VALUE));
assertNumericOverflow(format("INTERVAL '2' YEAR * %s", (long) Integer.MAX_VALUE + 1), format("Overflow multiplying interval year-month value by double: 24 * %s", (double) ((long) Integer.MAX_VALUE + 1)));
assertNumericOverflow(format("%s * INTERVAL '2' YEAR", (long) Integer.MAX_VALUE + 1), format("Overflow multiplying double by interval year-month value: %s * 24", (double) ((long) Integer.MAX_VALUE + 1)));
}

@Test
Expand All @@ -114,6 +125,8 @@ public void testDivide()

assertFunction("INTERVAL '3' YEAR / 2", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(18));
assertFunction("INTERVAL '4' YEAR / 4.8", INTERVAL_YEAR_MONTH, new SqlIntervalYearMonth(10));

assertNumericOverflow(format("INTERVAL '%s' MONTH / 0.5", Integer.MAX_VALUE), format("Overflow dividing interval year-month value by double: %s / 0.5", Integer.MAX_VALUE));
}

@Test
Expand Down
Loading