diff --git a/include/encryptsql.h b/include/encryptsql.h index 59632ca..d3bbba7 100755 --- a/include/encryptsql.h +++ b/include/encryptsql.h @@ -7,14 +7,16 @@ extern "C" const char *encryptOneSql(const char* sql, char** err_msg, const char extern "C" void decryptResult(int numberAttr, int numTuples, pAttrDescs pattDescs, pTuples ptuples); -struct EncryptInfo { - const char *sql; // 正在处理的sql - bool isFloatCol; //当前处理的col是否是float - // bool isFloatorIntCol; //当前处理的col是否是float或int - void *father; - bool isPeerColFloat; // where col_float = 10; 这种where条件中,10的AES需要放缩, isPeerColFloat为true表示在一个二元操作符中,操作数为float column. - bool isALeftOps; - bool isARightOps; - bool isFromAExpr = false; - bool isFromUpdate = false; -}; +struct EncryptInfo { + const char *sql; // 正在处理的sql + bool isFloatCol; //当前处理的col是否是float + bool isInt8Col; // 当前处理的col是否是int8/bigint + // bool isFloatorIntCol; //当前处理的col是否是float或int + void *father; + bool isPeerColFloat; // where col_float = 10; 这种where条件中,10的AES需要放缩, isPeerColFloat为true表示在一个二元操作符中,操作数为float column. + bool isPeerColInt8; // where col_int8 = 9999999999; 这种where条件中,大整数字面量可能被解析成Float节点,需要按int8处理. + bool isALeftOps; + bool isARightOps; + bool isFromAExpr = false; + bool isFromUpdate = false; +}; diff --git a/src/encryptsql/encryptstmt.cpp b/src/encryptsql/encryptstmt.cpp index c116c11..e48c05a 100755 --- a/src/encryptsql/encryptstmt.cpp +++ b/src/encryptsql/encryptstmt.cpp @@ -27,6 +27,8 @@ extern "C" #include "encryptsql/fieldmap.h" #include +#include +#include #include #include #include @@ -46,6 +48,30 @@ extern "C" #include "KeyDistribution/non_enc_client/client_interface.h" namespace { +bool IsInt8ColumnType(const char* type_name) +{ + if (!type_name) + return false; + + std::string type(type_name); + return type.find("int8") != std::string::npos || type.find("bigint") != std::string::npos; +} + +bool TryParseInt64Literal(const char* literal, int64_t* value) +{ + if (!literal || !value) + return false; + + errno = 0; + char* endptr = nullptr; + long long parsed = std::strtoll(literal, &endptr, 10); + if (errno != 0 || endptr == literal || (endptr && *endptr != '\0')) + return false; + + *value = static_cast(parsed); + return true; +} + bool SendDekViaTls(const std::string& dek_plain) { kd::client::ClientConfig cfg; @@ -189,14 +215,25 @@ static A_Const *encryptAConst(A_Const *aconst, T_Cipher encryptCipher, EncryptIn } else { auto tmpInt = (int64_t *) palloc(sizeof(int64_t)); double tmpDouble; + bool isTargetInt8 = info->isInt8Col || info->isPeerColInt8; // 去除放缩逻辑 - if (encryptCipher == CIPHER_RND) { + if (IsA(AConstValue, Float) && isTargetInt8 && TryParseInt64Literal(strVal(AConstValue), tmpInt)) { + isFloat = false; + if (info->isPeerColInt8) { + info->isPeerColInt8 = false; + } + plainText = (uint8_t *) tmpInt; + in_size = sizeof(int64_t); + } else if (encryptCipher == CIPHER_RND) { *tmpInt = intVal(AConstValue); // isFloat = true; // *tmpInt *= Float_Scale; + if (info->isPeerColInt8) { + info->isPeerColInt8 = false; + } plainText = (uint8_t *) tmpInt; in_size = sizeof(int64_t); - } else if (IsA(AConstValue, Float) || (info->isPeerColFloat)) // + } else if (IsA(AConstValue, Float)) // { isFloat = true; tmpDouble = atof(strVal(AConstValue)); @@ -204,6 +241,19 @@ static A_Const *encryptAConst(A_Const *aconst, T_Cipher encryptCipher, EncryptIn if (info->isPeerColFloat) { info->isPeerColFloat = false; } + if (info->isPeerColInt8) { + info->isPeerColInt8 = false; + } + plainText = (uint8_t *) &tmpDouble; + in_size = sizeof(double); + } else if (info->isPeerColFloat) { + isFloat = true; + if (IsA(AConstValue, Integer)) { + tmpDouble = static_cast(intVal(AConstValue)); + } else { + tmpDouble = atof(strVal(AConstValue)); + } + info->isPeerColFloat = false; plainText = (uint8_t *) &tmpDouble; in_size = sizeof(double); } else if (IsA(AConstValue, Integer)) { @@ -211,6 +261,9 @@ static A_Const *encryptAConst(A_Const *aconst, T_Cipher encryptCipher, EncryptIn // isFloat = true; isFloat = false; // *tmpInt *= Float_Scale; + if (info->isPeerColInt8) { + info->isPeerColInt8 = false; + } plainText = (uint8_t *) tmpInt; in_size = sizeof(int64_t); } else { @@ -443,9 +496,9 @@ static List *encryptStar(EncryptInfo *info) { // 处理from表的第一个匹配 t, patchedName, name); if (strcmp(name, cipherColName) == 0) { // 这个字段没有encryptCipher, 可能是个NOCRYPT加密 q:判断是否加密,如果没加密复制原列名,如果加密了则报错 //根据列的加密数来判断该列是否加密 - if (ncipher == 0) { - cipherColName = name; - } else { + if (ncipher == 0) { + cipherColName = name; + } else { string err_msg = name; err_msg = err_msg + ": You are using a feature not assigned to this column. Please check the features added when the column was created."; @@ -460,13 +513,15 @@ static List *encryptStar(EncryptInfo *info) { // 处理from表的第一个匹配 newCref->fields = lappend(newCref->fields, cipherColNameValue); p = lnext(p); char typebuf[128]; - if (IsA((Node *) info->father, A_Expr)) { - getColumnType(name, typebuf); - if (!strcmp(typebuf, "float")) { - info->isPeerColFloat = true; - } - } - } + if (IsA((Node *) info->father, A_Expr) && ncipher > 0) { + getColumnType(name, typebuf); + if (!strcmp(typebuf, "float")) { + info->isPeerColFloat = true; + } else if (IsInt8ColumnType(typebuf)) { + info->isPeerColInt8 = true; + } + } + } return newCref; }*/ @@ -640,7 +695,7 @@ static Node *encryptAExpr(A_Expr *expr, EncryptInfo *info) // 将表达式转为 //colname 在加密之前需要跳过colname获取到列密钥 char *encRname = NULL,*encLname = NULL; char *patchedName = NULL; - if (nodeTag(expr->rexpr) == T_ColumnRef) { + if (nodeTag(expr->rexpr) == T_ColumnRef) { is_rexpr_column = true; auto tmp1 = (ColumnRef *) expr->rexpr; ListCell *tmp2 = list_head(tmp1->fields); @@ -649,8 +704,8 @@ static Node *encryptAExpr(A_Expr *expr, EncryptInfo *info) // 将表达式转为 T_Cipher ciphers[CIPHER_COUNT]; getColumnCiphers(rexpr_name, ciphers, &ncipher); - T_Cipher ctype=CIPHER_NOCRYPT; - for(auto cipher : ciphers) { + T_Cipher ctype=CIPHER_NOCRYPT; + for(auto cipher : ciphers) { if(cipher == CIPHER_AES){ ctype = cipher; break; @@ -663,17 +718,25 @@ static Node *encryptAExpr(A_Expr *expr, EncryptInfo *info) // 将表达式转为 }else if(cipher == CIPHER_AESHMAC){ ctype = cipher; break; + } } - } - if (ncipher == 0) { - is_rexpr_plaintext = true; - } else if(strcmp(op,"=") == 0){ - t1 = t2 = ctype; - if(nodeTag(expr->lexpr) == T_A_Const){ - patchedName = addEncryptSubfix(ctype, rexpr_name); - encRname = getMappedName(T_STRING_COLUMN, patchedName, rexpr_name); - // cmk_mapperSetInfoCol(encRname); + if (nodeTag(expr->lexpr) == T_A_Const && ncipher > 0) { + char typebuf[128]; + getColumnType(rexpr_name, typebuf); + if (IsInt8ColumnType(typebuf)) { + info->isPeerColInt8 = true; + } + } + + if (ncipher == 0) { + is_rexpr_plaintext = true; + } else if(strcmp(op,"=") == 0){ + t1 = t2 = ctype; + if(nodeTag(expr->lexpr) == T_A_Const){ + patchedName = addEncryptSubfix(ctype, rexpr_name); + encRname = getMappedName(T_STRING_COLUMN, patchedName, rexpr_name); + // cmk_mapperSetInfoCol(encRname); DekInterface::setInfoCol(encRname); info->isFromAExpr = true; } @@ -693,8 +756,8 @@ static Node *encryptAExpr(A_Expr *expr, EncryptInfo *info) // 将表达式转为 T_Cipher ciphers[CIPHER_COUNT]; getColumnCiphers(lexpr_name, ciphers, &ncipher); - T_Cipher ctype=CIPHER_NOCRYPT; - for(auto cipher : ciphers) { + T_Cipher ctype=CIPHER_NOCRYPT; + for(auto cipher : ciphers) { if(cipher == CIPHER_AESHMAC){ ctype = cipher; break; @@ -707,17 +770,25 @@ static Node *encryptAExpr(A_Expr *expr, EncryptInfo *info) // 将表达式转为 }else if(cipher == CIPHER_AES){ ctype = CIPHER_AES; break; + } } - } - if (ncipher == 0) { - is_lexpr_plaintext = true; - } else if(strcmp(op,"=") == 0){ - t1 = t2 = ctype; - if(nodeTag(expr->rexpr) == T_A_Const){ - patchedName = addEncryptSubfix(ctype, lexpr_name); - encLname = getMappedName(T_STRING_COLUMN, patchedName, lexpr_name); - // cmk_mapperSetInfoCol(encLname); + if (nodeTag(expr->rexpr) == T_A_Const && ncipher > 0) { + char typebuf[128]; + getColumnType(lexpr_name, typebuf); + if (IsInt8ColumnType(typebuf)) { + info->isPeerColInt8 = true; + } + } + + if (ncipher == 0) { + is_lexpr_plaintext = true; + } else if(strcmp(op,"=") == 0){ + t1 = t2 = ctype; + if(nodeTag(expr->rexpr) == T_A_Const){ + patchedName = addEncryptSubfix(ctype, lexpr_name); + encLname = getMappedName(T_STRING_COLUMN, patchedName, lexpr_name); + // cmk_mapperSetInfoCol(encLname); DekInterface::setInfoCol(encLname); info->isFromAExpr = true; } @@ -1511,15 +1582,23 @@ static List *encryptValuesLists(List *valuesLists, List *cols, EncryptInfo *info // } else { // info->isFloatorIntCol = false; // } - if (string(colInfo->type).find("float") != string::npos) - { - info->isFloatCol = true; - } - else - { - info->isFloatCol = false; - } - if (IsA(n, FuncCall)) { + if (string(colInfo->type).find("float") != string::npos) + { + info->isFloatCol = true; + } + else + { + info->isFloatCol = false; + } + if (IsInt8ColumnType(colInfo->type)) + { + info->isInt8Col = true; + } + else + { + info->isInt8Col = false; + } + if (IsA(n, FuncCall)) { Node *tmpRes = dealWithSpecialFunction((FuncCall *) n, info); n = tmpRes ? tmpRes : n; }