Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[branch-2.1](function) Refine crypto functions signature to fix wrong result(#40285) #40648

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -640,11 +640,7 @@ private String paramsToSql() {
&& (fnName.getFunction().equalsIgnoreCase("aes_decrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2"))) {
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt"))) {
sb.append("\'***\'");
} else if (orderByElements.size() > 0 && i == len - orderByElements.size()) {
sb.append("ORDER BY ");
Expand Down Expand Up @@ -718,22 +714,13 @@ private String paramsToDigest() {
if (fnName.getFunction().equalsIgnoreCase("aes_decrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2")) {
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")) {
len = len - 1;
}
for (int i = 0; i < len; ++i) {
if (i == 1 && (fnName.getFunction().equalsIgnoreCase("aes_decrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2"))) {
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt"))) {
result.add("\'***\'");
} else {
result.add(children.get(i).toDigest());
Expand Down Expand Up @@ -1141,13 +1128,8 @@ private void analyzeBuiltinAggFunction(Analyzer analyzer) throws AnalysisExcepti
if ((fnName.getFunction().equalsIgnoreCase("aes_decrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2"))
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt"))
&& (children.size() == 2 || children.size() == 3)) {
String blockEncryptionMode = "";
Set<String> aesModes = new HashSet<>(Arrays.asList(
"AES_128_ECB",
"AES_192_ECB",
Expand Down Expand Up @@ -1181,80 +1163,33 @@ private void analyzeBuiltinAggFunction(Analyzer analyzer) throws AnalysisExcepti
"SM4_128_OFB",
"SM4_128_CTR"));

String blockEncryptionMode = "";
if (ConnectContext.get() != null) {
blockEncryptionMode = ConnectContext.get().getSessionVariable().getBlockEncryptionMode();
if (fnName.getFunction().equalsIgnoreCase("aes_decrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")
|| fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")) {
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")) {
if (StringUtils.isAllBlank(blockEncryptionMode)) {
blockEncryptionMode = "AES_128_ECB";
}
if (!aesModes.contains(blockEncryptionMode.toUpperCase())) {
throw new AnalysisException("session variable block_encryption_mode is invalid with aes");
}
if (children.size() == 2) {
boolean isECB = blockEncryptionMode.equalsIgnoreCase("AES_128_ECB")
|| blockEncryptionMode.equalsIgnoreCase("AES_192_ECB")
|| blockEncryptionMode.equalsIgnoreCase("AES_256_ECB");
if (fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")) {
if (!isECB) {
throw new AnalysisException(
"Incorrect parameter count in the call to native function 'aes_decrypt'");
}
} else if (fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")) {
if (!isECB) {
throw new AnalysisException(
"Incorrect parameter count in the call to native function 'aes_encrypt'");
}
} else {
// if there are only 2 params, we need set encryption mode to AES_128_ECB
// this keeps the behavior consistent with old doris ver.
blockEncryptionMode = "AES_128_ECB";
}
}
}
if (fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2")) {
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")) {
if (StringUtils.isAllBlank(blockEncryptionMode)) {
blockEncryptionMode = "SM4_128_ECB";
}
if (!sm4Modes.contains(blockEncryptionMode.toUpperCase())) {
throw new AnalysisException(
"session variable block_encryption_mode is invalid with sm4");
}
if (children.size() == 2) {
if (fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")) {
throw new AnalysisException(
"Incorrect parameter count in the call to native function 'sm4_decrypt'");
} else if (fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2")) {
throw new AnalysisException(
"Incorrect parameter count in the call to native function 'sm4_encrypt'");
} else {
// if there are only 2 params, we need add an empty string as the third param
// and set encryption mode to SM4_128_ECB
// this keeps the behavior consistent with old doris ver.
children.add(new StringLiteral(""));
blockEncryptionMode = "SM4_128_ECB";
}
}
}
} else {
throw new AnalysisException("cannot get session variable `block_encryption_mode`, "
+ "please explicitly specify by using 4-args function");
}
if (!blockEncryptionMode.equals(children.get(children.size() - 1).toString())) {
children.add(new StringLiteral(blockEncryptionMode));
}

if (fnName.getFunction().equalsIgnoreCase("aes_decrypt_v2")) {
fnName = FunctionName.createBuiltinName("aes_decrypt");
} else if (fnName.getFunction().equalsIgnoreCase("aes_encrypt_v2")) {
fnName = FunctionName.createBuiltinName("aes_encrypt");
} else if (fnName.getFunction().equalsIgnoreCase("sm4_decrypt_v2")) {
fnName = FunctionName.createBuiltinName("sm4_decrypt");
} else if (fnName.getFunction().equalsIgnoreCase("sm4_encrypt_v2")) {
fnName = FunctionName.createBuiltinName("sm4_encrypt");
}
children.add(new StringLiteral(blockEncryptionMode));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Acos;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesDecrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesDecryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesEncrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesEncryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AppendTrailingCharIfAbsent;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Array;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayApply;
Expand Down Expand Up @@ -358,9 +356,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm3;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm3sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4Decrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4DecryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4Encrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm4EncryptV2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Space;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SplitByChar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SplitByString;
Expand Down Expand Up @@ -465,9 +461,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Abs.class, "abs"),
scalar(Acos.class, "acos"),
scalar(AesDecrypt.class, "aes_decrypt"),
scalar(AesDecryptV2.class, "aes_decrypt_v2"),
scalar(AesEncrypt.class, "aes_encrypt"),
scalar(AesEncryptV2.class, "aes_encrypt_v2"),
scalar(AppendTrailingCharIfAbsent.class, "append_trailing_char_if_absent"),
scalar(Array.class, "array"),
scalar(ArrayApply.class, "array_apply"),
Expand Down Expand Up @@ -823,9 +817,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Sm3.class, "sm3"),
scalar(Sm3sum.class, "sm3sum"),
scalar(Sm4Decrypt.class, "sm4_decrypt"),
scalar(Sm4DecryptV2.class, "sm4_decrypt_v2"),
scalar(Sm4Encrypt.class, "sm4_encrypt"),
scalar(Sm4EncryptV2.class, "sm4_encrypt_v2"),
scalar(Space.class, "space"),
scalar(SplitByChar.class, "split_by_char"),
scalar(SplitByString.class, "split_by_string"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
Expand Down Expand Up @@ -58,16 +57,7 @@ public class AesDecrypt extends AesCryptoFunction {
* AesDecrypt
*/
public AesDecrypt(Expression arg0, Expression arg1) {
// if there are only 2 params, we need set encryption mode to AES_128_ECB
// this keeps the behavior consistent with old doris ver.
super("aes_decrypt", arg0, arg1, new StringLiteral("AES_128_ECB"));

// check if encryptionMode from session variables is valid
StringLiteral encryptionMode = CryptoFunction.getDefaultBlockEncryptionMode("AES_128_ECB");
if (!AES_MODES.contains(encryptionMode.getValue())) {
throw new AnalysisException(
"session variable block_encryption_mode is invalid with aes");
}
super("aes_decrypt", arg0, arg1, new StringLiteral(""), getDefaultBlockEncryptionMode());
}

public AesDecrypt(Expression arg0, Expression arg1, Expression arg2) {
Expand All @@ -89,7 +79,7 @@ public AesDecrypt withChildren(List<Expression> children) {
} else if (children().size() == 3) {
return new AesDecrypt(children.get(0), children.get(1), children.get(2));
} else {
return new AesDecrypt(children.get(0), children.get(1), children.get(2), (StringLiteral) children.get(3));
return new AesDecrypt(children.get(0), children.get(1), children.get(2), children.get(3));
}
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
Expand Down Expand Up @@ -58,16 +57,7 @@ public class AesEncrypt extends AesCryptoFunction {
* Some javadoc for checkstyle...
*/
public AesEncrypt(Expression arg0, Expression arg1) {
// if there are only 2 params, we need set encryption mode to AES_128_ECB
// this keeps the behavior consistent with old doris ver.
super("aes_encrypt", arg0, arg1, new StringLiteral("AES_128_ECB"));

// check if encryptionMode from session variables is valid
StringLiteral encryptionMode = CryptoFunction.getDefaultBlockEncryptionMode("AES_128_ECB");
if (!AES_MODES.contains(encryptionMode.getValue())) {
throw new AnalysisException(
"session variable block_encryption_mode is invalid with aes");
}
super("aes_encrypt", arg0, arg1, new StringLiteral(""), getDefaultBlockEncryptionMode());
}

public AesEncrypt(Expression arg0, Expression arg1, Expression arg2) {
Expand All @@ -89,7 +79,7 @@ public AesEncrypt withChildren(List<Expression> children) {
} else if (children().size() == 3) {
return new AesEncrypt(children.get(0), children.get(1), children.get(2));
} else {
return new AesEncrypt(children.get(0), children.get(1), children.get(2), (StringLiteral) children.get(3));
return new AesEncrypt(children.get(0), children.get(1), children.get(2), children.get(3));
}
}

Expand Down
Loading
Loading