diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java index 0e4df086fb..31e958fa3b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -500,26 +500,30 @@ private static DefaultFunctionResolver truncate() { FunctionDSL.impl( FunctionDSL.nullMissingHandling( (x, y) -> new ExprLongValue( - new BigDecimal(x.integerValue()).setScale(y.integerValue(), - RoundingMode.DOWN).longValue())), + BigDecimal.valueOf(x.integerValue()).setScale(y.integerValue(), + x.integerValue() > 0 ? RoundingMode.FLOOR : RoundingMode.CEILING) + .longValue())), LONG, INTEGER, INTEGER), FunctionDSL.impl( FunctionDSL.nullMissingHandling( (x, y) -> new ExprLongValue( - new BigDecimal(x.integerValue()).setScale(y.integerValue(), - RoundingMode.DOWN).longValue())), + BigDecimal.valueOf(x.longValue()).setScale(y.integerValue(), + x.longValue() > 0 ? RoundingMode.FLOOR : RoundingMode.CEILING) + .longValue())), LONG, LONG, INTEGER), FunctionDSL.impl( FunctionDSL.nullMissingHandling( (x, y) -> new ExprDoubleValue( - new BigDecimal(x.floatValue()).setScale(y.integerValue(), - RoundingMode.DOWN).doubleValue())), + BigDecimal.valueOf(x.floatValue()).setScale(y.integerValue(), + x.floatValue() > 0 ? RoundingMode.FLOOR : RoundingMode.CEILING) + .doubleValue())), DOUBLE, FLOAT, INTEGER), FunctionDSL.impl( FunctionDSL.nullMissingHandling( (x, y) -> new ExprDoubleValue( - new BigDecimal(x.doubleValue()).setScale(y.integerValue(), - RoundingMode.DOWN).doubleValue())), + BigDecimal.valueOf(x.doubleValue()).setScale(y.integerValue(), + x.doubleValue() > 0 ? RoundingMode.FLOOR : RoundingMode.CEILING) + .doubleValue())), DOUBLE, DOUBLE, INTEGER)); } diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java index 3d7cdaeb41..a3c053c585 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunctionTest.java @@ -192,12 +192,12 @@ public void ceil_int_value(Integer value) { assertThat( ceil.valueOf(valueEnv()), allOf(hasType(INTEGER), hasValue((int) Math.ceil(value)))); - assertEquals(String.format("ceil(%s)", value.toString()), ceil.toString()); + assertEquals(String.format("ceil(%s)", value), ceil.toString()); FunctionExpression ceiling = DSL.ceiling(DSL.literal(value)); assertThat( ceiling.valueOf(valueEnv()), allOf(hasType(INTEGER), hasValue((int) Math.ceil(value)))); - assertEquals(String.format("ceiling(%s)", value.toString()), ceiling.toString()); + assertEquals(String.format("ceiling(%s)", value), ceiling.toString()); } /** @@ -209,12 +209,12 @@ public void ceil_long_value(Long value) { FunctionExpression ceil = DSL.ceil(DSL.literal(value)); assertThat( ceil.valueOf(valueEnv()), allOf(hasType(INTEGER), hasValue((int) Math.ceil(value)))); - assertEquals(String.format("ceil(%s)", value.toString()), ceil.toString()); + assertEquals(String.format("ceil(%s)", value), ceil.toString()); FunctionExpression ceiling = DSL.ceiling(DSL.literal(value)); assertThat( ceiling.valueOf(valueEnv()), allOf(hasType(INTEGER), hasValue((int) Math.ceil(value)))); - assertEquals(String.format("ceiling(%s)", value.toString()), ceiling.toString()); + assertEquals(String.format("ceiling(%s)", value), ceiling.toString()); } /** @@ -226,12 +226,12 @@ public void ceil_float_value(Float value) { FunctionExpression ceil = DSL.ceil(DSL.literal(value)); assertThat( ceil.valueOf(valueEnv()), allOf(hasType(INTEGER), hasValue((int) Math.ceil(value)))); - assertEquals(String.format("ceil(%s)", value.toString()), ceil.toString()); + assertEquals(String.format("ceil(%s)", value), ceil.toString()); FunctionExpression ceiling = DSL.ceiling(DSL.literal(value)); assertThat( ceiling.valueOf(valueEnv()), allOf(hasType(INTEGER), hasValue((int) Math.ceil(value)))); - assertEquals(String.format("ceiling(%s)", value.toString()), ceiling.toString()); + assertEquals(String.format("ceiling(%s)", value), ceiling.toString()); } /** @@ -243,12 +243,12 @@ public void ceil_double_value(Double value) { FunctionExpression ceil = DSL.ceil(DSL.literal(value)); assertThat( ceil.valueOf(valueEnv()), allOf(hasType(INTEGER), hasValue((int) Math.ceil(value)))); - assertEquals(String.format("ceil(%s)", value.toString()), ceil.toString()); + assertEquals(String.format("ceil(%s)", value), ceil.toString()); FunctionExpression ceiling = DSL.ceiling(DSL.literal(value)); assertThat( ceiling.valueOf(valueEnv()), allOf(hasType(INTEGER), hasValue((int) Math.ceil(value)))); - assertEquals(String.format("ceiling(%s)", value.toString()), ceiling.toString()); + assertEquals(String.format("ceiling(%s)", value), ceiling.toString()); } /** @@ -1721,12 +1721,13 @@ public void sqrt_missing_value() { * Test truncate with integer value. */ @ParameterizedTest(name = "truncate({0}, {1})") - @ValueSource(ints = {2, -2}) + @ValueSource(ints = {2, -2, Integer.MAX_VALUE, Integer.MIN_VALUE}) public void truncate_int_value(Integer value) { FunctionExpression truncate = DSL.truncate(DSL.literal(value), DSL.literal(1)); assertThat( truncate.valueOf(valueEnv()), allOf(hasType(LONG), - hasValue(new BigDecimal(value).setScale(1, RoundingMode.DOWN).longValue()))); + hasValue(BigDecimal.valueOf(value).setScale(1, + value > 0 ? RoundingMode.FLOOR : RoundingMode.CEILING).longValue()))); assertEquals(String.format("truncate(%s, 1)", value), truncate.toString()); } @@ -1734,12 +1735,13 @@ public void truncate_int_value(Integer value) { * Test truncate with long value. */ @ParameterizedTest(name = "truncate({0}, {1})") - @ValueSource(longs = {2L, -2L}) + @ValueSource(longs = {2L, -2L, Long.MAX_VALUE, Long.MIN_VALUE}) public void truncate_long_value(Long value) { FunctionExpression truncate = DSL.truncate(DSL.literal(value), DSL.literal(1)); assertThat( truncate.valueOf(valueEnv()), allOf(hasType(LONG), - hasValue(new BigDecimal(value).setScale(1, RoundingMode.DOWN).longValue()))); + hasValue(BigDecimal.valueOf(value).setScale(1, + value > 0 ? RoundingMode.FLOOR : RoundingMode.CEILING).longValue()))); assertEquals(String.format("truncate(%s, 1)", value), truncate.toString()); } @@ -1747,12 +1749,13 @@ public void truncate_long_value(Long value) { * Test truncate with float value. */ @ParameterizedTest(name = "truncate({0}, {1})") - @ValueSource(floats = {2F, -2F}) + @ValueSource(floats = {2F, -2F, Float.MAX_VALUE, Float.MIN_VALUE}) public void truncate_float_value(Float value) { FunctionExpression truncate = DSL.truncate(DSL.literal(value), DSL.literal(1)); assertThat( truncate.valueOf(valueEnv()), allOf(hasType(DOUBLE), - hasValue(new BigDecimal(value).setScale(1, RoundingMode.DOWN).doubleValue()))); + hasValue(BigDecimal.valueOf(value).setScale(1, + value > 0 ? RoundingMode.FLOOR : RoundingMode.CEILING).doubleValue()))); assertEquals(String.format("truncate(%s, 1)", value), truncate.toString()); } @@ -1760,12 +1763,16 @@ public void truncate_float_value(Float value) { * Test truncate with double value. */ @ParameterizedTest(name = "truncate({0}, {1})") - @ValueSource(doubles = {2D, -2D}) + @ValueSource(doubles = {2D, -9.223372036854776e+18D, -2147483649.0D, -2147483648.0D, + -32769.0D, -32768.0D, -34.84D, -2.0D, -1.2D, -1.0D, 0.0D, 1.0D, + 1.3D, 2.0D, 1004.3D, 32767.0D, 32768.0D, 2147483647.0D, 2147483648.0D, + 9.223372036854776e+18D, Double.MAX_VALUE, Double.MIN_VALUE}) public void truncate_double_value(Double value) { FunctionExpression truncate = DSL.truncate(DSL.literal(value), DSL.literal(1)); assertThat( truncate.valueOf(valueEnv()), allOf(hasType(DOUBLE), - hasValue(new BigDecimal(value).setScale(1, RoundingMode.DOWN).doubleValue()))); + hasValue(BigDecimal.valueOf(value).setScale(1, + value > 0 ? RoundingMode.FLOOR : RoundingMode.CEILING).doubleValue()))); assertEquals(String.format("truncate(%s, 1)", value), truncate.toString()); } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java index efa16ba9d7..f2d1bb7d28 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/MathematicalFunctionIT.java @@ -142,6 +142,30 @@ public void testTruncate() throws IOException { result = executeQuery("select truncate(-56, -1)"); verifySchema(result, schema("truncate(-56, -1)", null, "long")); verifyDataRows(result, rows(-50)); + + result = executeQuery("select truncate(33.33344, -1)"); + verifySchema(result, schema("truncate(33.33344, -1)", null, "double")); + verifyDataRows(result, rows(30.0)); + + result = executeQuery("select truncate(33.33344, 2)"); + verifySchema(result, schema("truncate(33.33344, 2)", null, "double")); + verifyDataRows(result, rows(33.33)); + + result = executeQuery("select truncate(33.33344, 100)"); + verifySchema(result, schema("truncate(33.33344, 100)", null, "double")); + verifyDataRows(result, rows(33.33344)); + + result = executeQuery("select truncate(33.33344, 0)"); + verifySchema(result, schema("truncate(33.33344, 0)", null, "double")); + verifyDataRows(result, rows(33.0)); + + result = executeQuery("select truncate(33.33344, 4)"); + verifySchema(result, schema("truncate(33.33344, 4)", null, "double")); + verifyDataRows(result, rows(33.3334)); + + result = executeQuery(String.format("select truncate(%s, 6)", Math.PI)); + verifySchema(result, schema(String.format("truncate(%s, 6)", Math.PI), null, "double")); + verifyDataRows(result, rows(3.141592)); } @Test