修复Int8类型,当数据超过int4范围时,语法树解析为float导致解密错误

This commit is contained in:
blue-lemon0104
2026-04-17 11:36:22 +08:00
parent 37595bd51b
commit 675146e477
2 changed files with 136 additions and 55 deletions

View File

@@ -10,9 +10,11 @@ extern "C" void decryptResult(int numberAttr, int numTuples, pAttrDescs pattDesc
struct EncryptInfo {
const char *sql; // 正在处理的sql
bool isFloatCol; //当前处理的col是否是float
bool isInt8Col; // 当前处理的col是否是int8/bigint
// bool isFloatorIntCol; //当前处理的col是否是float或int
void *father;
bool isPeerColFloat; // where col_float = 10; 这种where条件中10的AES需要放缩 isPeerColFloat为true表示在一个二元操作符中操作数为float column.
bool isPeerColInt8; // where col_int8 = 9999999999; 这种where条件中大整数字面量可能被解析成Float节点需要按int8处理.
bool isALeftOps;
bool isARightOps;
bool isFromAExpr = false;

View File

@@ -27,6 +27,8 @@ extern "C"
#include "encryptsql/fieldmap.h"
#include <assert.h>
#include <cerrno>
#include <cstdlib>
#include <stdexcept>
#include <set>
#include <vector>
@@ -46,6 +48,30 @@ extern "C"
#include "KeyDistribution/non_enc_client/client_interface.h"
namespace {
bool IsInt8ColumnType(const char* type_name)
{
if (!type_name)
return false;
std::string type(type_name);
return type.find("int8") != std::string::npos || type.find("bigint") != std::string::npos;
}
bool TryParseInt64Literal(const char* literal, int64_t* value)
{
if (!literal || !value)
return false;
errno = 0;
char* endptr = nullptr;
long long parsed = std::strtoll(literal, &endptr, 10);
if (errno != 0 || endptr == literal || (endptr && *endptr != '\0'))
return false;
*value = static_cast<int64_t>(parsed);
return true;
}
bool SendDekViaTls(const std::string& dek_plain)
{
kd::client::ClientConfig cfg;
@@ -189,14 +215,25 @@ static A_Const *encryptAConst(A_Const *aconst, T_Cipher encryptCipher, EncryptIn
} else {
auto tmpInt = (int64_t *) palloc(sizeof(int64_t));
double tmpDouble;
bool isTargetInt8 = info->isInt8Col || info->isPeerColInt8;
// 去除放缩逻辑
if (encryptCipher == CIPHER_RND) {
if (IsA(AConstValue, Float) && isTargetInt8 && TryParseInt64Literal(strVal(AConstValue), tmpInt)) {
isFloat = false;
if (info->isPeerColInt8) {
info->isPeerColInt8 = false;
}
plainText = (uint8_t *) tmpInt;
in_size = sizeof(int64_t);
} else if (encryptCipher == CIPHER_RND) {
*tmpInt = intVal(AConstValue);
// isFloat = true;
// *tmpInt *= Float_Scale;
if (info->isPeerColInt8) {
info->isPeerColInt8 = false;
}
plainText = (uint8_t *) tmpInt;
in_size = sizeof(int64_t);
} else if (IsA(AConstValue, Float) || (info->isPeerColFloat)) //
} else if (IsA(AConstValue, Float)) //
{
isFloat = true;
tmpDouble = atof(strVal(AConstValue));
@@ -204,6 +241,19 @@ static A_Const *encryptAConst(A_Const *aconst, T_Cipher encryptCipher, EncryptIn
if (info->isPeerColFloat) {
info->isPeerColFloat = false;
}
if (info->isPeerColInt8) {
info->isPeerColInt8 = false;
}
plainText = (uint8_t *) &tmpDouble;
in_size = sizeof(double);
} else if (info->isPeerColFloat) {
isFloat = true;
if (IsA(AConstValue, Integer)) {
tmpDouble = static_cast<double>(intVal(AConstValue));
} else {
tmpDouble = atof(strVal(AConstValue));
}
info->isPeerColFloat = false;
plainText = (uint8_t *) &tmpDouble;
in_size = sizeof(double);
} else if (IsA(AConstValue, Integer)) {
@@ -211,6 +261,9 @@ static A_Const *encryptAConst(A_Const *aconst, T_Cipher encryptCipher, EncryptIn
// isFloat = true;
isFloat = false;
// *tmpInt *= Float_Scale;
if (info->isPeerColInt8) {
info->isPeerColInt8 = false;
}
plainText = (uint8_t *) tmpInt;
in_size = sizeof(int64_t);
} else {
@@ -460,10 +513,12 @@ static List *encryptStar(EncryptInfo *info) { // 处理from表的第一个匹配
newCref->fields = lappend(newCref->fields, cipherColNameValue);
p = lnext(p);
char typebuf[128];
if (IsA((Node *) info->father, A_Expr)) {
if (IsA((Node *) info->father, A_Expr) && ncipher > 0) {
getColumnType(name, typebuf);
if (!strcmp(typebuf, "float")) {
info->isPeerColFloat = true;
} else if (IsInt8ColumnType(typebuf)) {
info->isPeerColInt8 = true;
}
}
}
@@ -666,6 +721,14 @@ static Node *encryptAExpr(A_Expr *expr, EncryptInfo *info) // 将表达式转为
}
}
if (nodeTag(expr->lexpr) == T_A_Const && ncipher > 0) {
char typebuf[128];
getColumnType(rexpr_name, typebuf);
if (IsInt8ColumnType(typebuf)) {
info->isPeerColInt8 = true;
}
}
if (ncipher == 0) {
is_rexpr_plaintext = true;
} else if(strcmp(op,"=") == 0){
@@ -710,6 +773,14 @@ static Node *encryptAExpr(A_Expr *expr, EncryptInfo *info) // 将表达式转为
}
}
if (nodeTag(expr->rexpr) == T_A_Const && ncipher > 0) {
char typebuf[128];
getColumnType(lexpr_name, typebuf);
if (IsInt8ColumnType(typebuf)) {
info->isPeerColInt8 = true;
}
}
if (ncipher == 0) {
is_lexpr_plaintext = true;
} else if(strcmp(op,"=") == 0){
@@ -1519,6 +1590,14 @@ static List *encryptValuesLists(List *valuesLists, List *cols, EncryptInfo *info
{
info->isFloatCol = false;
}
if (IsInt8ColumnType(colInfo->type))
{
info->isInt8Col = true;
}
else
{
info->isInt8Col = false;
}
if (IsA(n, FuncCall)) {
Node *tmpRes = dealWithSpecialFunction((FuncCall *) n, info);
n = tmpRes ? tmpRes : n;