From 47c20574f745072244ccacc77d3d6b0b6ef5f5c2 Mon Sep 17 00:00:00 2001 From: blue-lemon0104 <1362203478@qq.com> Date: Fri, 17 Apr 2026 13:05:30 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Ddecimal=E8=A7=A3=E6=9E=90?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E9=97=AE=E9=A2=98=EF=BC=8C=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E4=BB=A5%.15g=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/encryptsql/decryptres.cpp | 53 ++++++++++----------- src/encryptsql/encryptstmt.cpp | 87 +++++++++++++++++++++++++--------- src/utils/utils.cpp | 31 +++++++----- src/utils/utils.h | 10 ++-- 4 files changed, 116 insertions(+), 65 deletions(-) diff --git a/src/encryptsql/decryptres.cpp b/src/encryptsql/decryptres.cpp index 296d8af..5c04781 100755 --- a/src/encryptsql/decryptres.cpp +++ b/src/encryptsql/decryptres.cpp @@ -104,15 +104,15 @@ void decryptResult(int numberAttr, int numTuples, pAttrDescs pattDescs, pTuples else ctype = TYPE_INTEGER; - void (*anyTypetoString)(const char *, char *const, void *) = NULL; // 把buffer转为相应类型的数据。 - switch (ctype) { - // case TYPE_FLOAT: - // anyTypetoString = &FloatTypetoString; - // break; - - case TYPE_STRING: - - anyTypetoString = NULL; //无需转换 + void (*anyTypetoString)(const char *, char *const, void *) = NULL; // 把buffer转为相应类型的数据。 + switch (ctype) { + case TYPE_FLOAT: + anyTypetoString = &DoubleTypetoString; + break; + + case TYPE_STRING: + + anyTypetoString = NULL; //无需转换 break; case TYPE_INTEGER: anyTypetoString = &IntTypetoString; @@ -166,15 +166,12 @@ void decryptResult(int numberAttr, int numTuples, pAttrDescs pattDescs, pTuples (char*)dek.c_str(), // 使用DEK (char*)aesOut.get() ); - // 确保输出大小正确 - if (outSize >= sizeof(double)) { - double d = *(double*)aesOut.get(); - std::ostringstream ss; - ss << d; - strcpy(tuples[i][j].value, ss.str().c_str()); - } - delete[] buff; - } else { + // 确保输出大小正确 + if (outSize >= sizeof(double)) { + DoubleTypetoString((const char*)aesOut.get(), tuples[i][j].value, (void *) coltypename.get()); + } + delete[] buff; + } else { size_t buffSZ = tuples[i][j].len / 2 + 1; // 密文buffer长度, 应该大于hexstr长度的一半 size_t plainBuffSZ = buffSZ; uint8_t *plainBuff = nullptr; @@ -211,15 +208,15 @@ void decryptResult(int numberAttr, int numTuples, pAttrDescs pattDescs, pTuples std::string dek; // cmk_mapperGetDekByCol(cipherColName,dek); DekInterface::getDekColLevel(cipherColName, dek); - plainBuff = encryptValue(type, pbuff, buffSZ, &plainBuffSZ, dek.c_str(), false); - //KMS test end; - // plainBuff = encryptValue(t, pbuff, buffSZ, &plainBuffSZ, false); - if (anyTypetoString) { - //fixed by qxy for RND - if (type == CIPHER_RND) { - long l = *(long *)plainBuff; - std::ostringstream ss; - ss << l; + plainBuff = encryptValue(type, pbuff, buffSZ, &plainBuffSZ, dek.c_str(), false); + //KMS test end; + // plainBuff = encryptValue(t, pbuff, buffSZ, &plainBuffSZ, false); + if (anyTypetoString) { + //fixed by qxy for RND + if (type == CIPHER_RND && ctype != TYPE_FLOAT) { + long l = *(long *)plainBuff; + std::ostringstream ss; + ss << l; strcpy(tuples[i][j].value, ss.str().c_str()); } else { anyTypetoString((const char *) plainBuff, tuples[i][j].value, (void *) coltypename.get()); @@ -264,4 +261,4 @@ void decryptResult(int numberAttr, int numTuples, pAttrDescs pattDescs, pTuples counter->count("SQL Decryption", timer.passedTimeMicroSecond()); cleanup(); -} \ No newline at end of file +} diff --git a/src/encryptsql/encryptstmt.cpp b/src/encryptsql/encryptstmt.cpp index e48c05a..eca008c 100755 --- a/src/encryptsql/encryptstmt.cpp +++ b/src/encryptsql/encryptstmt.cpp @@ -57,6 +57,24 @@ bool IsInt8ColumnType(const char* type_name) return type.find("int8") != std::string::npos || type.find("bigint") != std::string::npos; } +bool IsFloatColumnType(const char* type_name) +{ + return type_name && unifyColumnType(type_name) == TYPE_FLOAT; +} + +bool ShouldSetPeerFloat(const char* type_name, EncryptInfo* info) +{ + if (!IsFloatColumnType(type_name)) + return false; + if (!strcmp(type_name, "float")) + return true; + if (!info || !IsA((Node *) info->father, A_Expr)) + return false; + + const char* op = getAExprOp((A_Expr *) info->father); + return op && !strcmp(op, "="); +} + bool TryParseInt64Literal(const char* literal, int64_t* value) { if (!literal || !value) @@ -224,15 +242,29 @@ static A_Const *encryptAConst(A_Const *aconst, T_Cipher encryptCipher, EncryptIn } plainText = (uint8_t *) tmpInt; in_size = sizeof(int64_t); - } else if (encryptCipher == CIPHER_RND) { - *tmpInt = intVal(AConstValue); -// isFloat = true; -// *tmpInt *= Float_Scale; + } else if (encryptCipher == CIPHER_RND || encryptCipher == CIPHER_RNDSM4CK) { + if (isFloat || info->isPeerColFloat || IsA(AConstValue, Float)) { + isFloat = true; + if (IsA(AConstValue, Integer)) { + tmpDouble = static_cast(intVal(AConstValue)); + } else { + tmpDouble = atof(strVal(AConstValue)); + } + plainText = (uint8_t *) &tmpDouble; + in_size = sizeof(double); + } else { + *tmpInt = intVal(AConstValue); +// isFloat = true; +// *tmpInt *= Float_Scale; + plainText = (uint8_t *) tmpInt; + in_size = sizeof(int64_t); + } + if (info->isPeerColFloat) { + info->isPeerColFloat = false; + } if (info->isPeerColInt8) { info->isPeerColInt8 = false; } - plainText = (uint8_t *) tmpInt; - in_size = sizeof(int64_t); } else if (IsA(AConstValue, Float)) // { isFloat = true; @@ -257,17 +289,28 @@ static A_Const *encryptAConst(A_Const *aconst, T_Cipher encryptCipher, EncryptIn plainText = (uint8_t *) &tmpDouble; in_size = sizeof(double); } else if (IsA(AConstValue, Integer)) { - *tmpInt = intVal(AConstValue); - // isFloat = true; - isFloat = false; - // *tmpInt *= Float_Scale; + if (isFloat || info->isPeerColFloat) { + isFloat = true; + tmpDouble = static_cast(intVal(AConstValue)); + plainText = (uint8_t *) &tmpDouble; + in_size = sizeof(double); + if (info->isPeerColFloat) { + info->isPeerColFloat = false; + } + } else { + *tmpInt = intVal(AConstValue); + // isFloat = true; + isFloat = false; + // *tmpInt *= Float_Scale; + plainText = (uint8_t *) tmpInt; + in_size = sizeof(int64_t); + } if (info->isPeerColInt8) { info->isPeerColInt8 = false; } - plainText = (uint8_t *) tmpInt; - in_size = sizeof(int64_t); } else { *tmpInt = intVal(AConstValue); + tmpDouble = static_cast(*tmpInt); // if (info->isFloatorIntCol) { // 当前列是float或Int列 // isFloat = true; // *tmpInt *= Float_Scale; @@ -611,14 +654,14 @@ static ColumnRef *encryptColumnRef(ColumnRef *cref, T_Cipher encryptCipher, Encr } cipherColNameValue = makeString((char *) cipherColName); newCref->fields = lappend(newCref->fields, cipherColNameValue); - p = lnext(p); - char typebuf[128]; - if (IsA((Node *) info->father, A_Expr)) { - getColumnType(name, typebuf); - if (!strcmp(typebuf, "float")) { - info->isPeerColFloat = true; - } - } + p = lnext(p); + char typebuf[128]; + if (IsA((Node *) info->father, A_Expr)) { + getColumnType(name, typebuf); + if (ShouldSetPeerFloat(typebuf, info)) { + info->isPeerColFloat = true; + } + } } return newCref; } @@ -1581,8 +1624,8 @@ static List *encryptValuesLists(List *valuesLists, List *cols, EncryptInfo *info // info->isFloatorIntCol = true; // } else { // info->isFloatorIntCol = false; - // } - if (string(colInfo->type).find("float") != string::npos) + // } + if (IsFloatColumnType(colInfo->type)) { info->isFloatCol = true; } diff --git a/src/utils/utils.cpp b/src/utils/utils.cpp index 0a62d7c..c14968c 100755 --- a/src/utils/utils.cpp +++ b/src/utils/utils.cpp @@ -429,10 +429,10 @@ void IntTypetoString(const char *buf, char *const v, void *others) { strcpy(v, ss.str().c_str()); } -void FloatTypetoString(const char *buf, char *v, - void *others) // TODO: // buf 不一定是 double的内存排布,还有可能是long, 因为词法分析不区分 (long)1, (double)1 -{ - long l = *(long *) buf; +void FloatTypetoString(const char *buf, char *v, + void *others) // TODO: // buf 不一定是 double的内存排布,还有可能是long, 因为词法分析不区分 (long)1, (double)1 +{ + long l = *(long *) buf; double f = (l / Float_Scale) * 1.0; sprintf(v, "%.10f", f); size_t sz = strlen(v); @@ -454,13 +454,22 @@ void FloatTypetoString(const char *buf, char *v, break; } *c = '\0'; - - } -} - -void FloatTypetoString2(const char *buf, char *v, - void *others) // TODO: // buf 不一定是 double的内存排布,还有可能是long, 因为词法分析不区分 (long)1, (double)1 -{ + + } +} + +void DoubleTypetoString(const char *buf, char *v, void *others) +{ + double d = 0.0; + memcpy(&d, buf, sizeof(double)); + sprintf(v, "%.15g", d); + if (!strcmp(v, "-0")) + strcpy(v, "0"); +} + +void FloatTypetoString2(const char *buf, char *v, + void *others) // TODO: // buf 不一定是 double的内存排布,还有可能是long, 因为词法分析不区分 (long)1, (double)1 +{ long l = *(long *) buf; double f = l * 1.0; sprintf(v, "%.10f", f); diff --git a/src/utils/utils.h b/src/utils/utils.h index ead9ee0..b805e3b 100755 --- a/src/utils/utils.h +++ b/src/utils/utils.h @@ -151,9 +151,11 @@ COLUMN_TYPE unifyColumnType(const std::string &s); void IntTypetoString(const char *buf, char *const v, void *others); -void FloatTypetoString(const char *buf, char *const v, void *others); - -void FloatTypetoString2(const char *buf, char *const v, void *others); +void FloatTypetoString(const char *buf, char *const v, void *others); + +void DoubleTypetoString(const char *buf, char *const v, void *others); + +void FloatTypetoString2(const char *buf, char *const v, void *others); void FloatTypetoString3(const char *buf, char *const v, void *others); @@ -174,4 +176,4 @@ long GetCurrentTimestamp(void); void fmttime(time_t lt1, char *res); -uint64_t mygettid(); \ No newline at end of file +uint64_t mygettid();