Files
encryptsql/src/KMSAdapter/dek_interface.cpp
blue-lemon0104 46fa58f6f8 merge
2026-04-07 15:45:41 +08:00

1000 lines
35 KiB
C++
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "kms/kms_interface.hpp"
#include "kmsAdapter/dek_interface.hpp"
#include <fstream> // 必须包含这个头文件
namespace DekInterface{
/**
 * Parse a DEK rotation command.
 *
 * Accepted forms (keywords are case-insensitive, the table name may be wrapped
 * in backticks, a trailing ';' is optional):
 *   ROTATE DEK NOW <table> -TK                  rotate the table-level key only
 *   ROTATE DEK NOW <table> -CK col1, col2, ...  rotate the listed column keys
 *   ROTATE DEK NOW <table> -ALL                 rotate the table key and all column keys
 *
 * Side effect: user_name / db_name are stored via DekInterface::setInfoUser /
 * setInfoDb before parsing, whether or not the command turns out to be valid.
 *
 * @param command   raw command text; table and column names keep their original case
 * @param user_name current user, passed to setInfoUser
 * @param db_name   current database, passed to setInfoDb
 * @return on success valid == true with type and tableName set (plus cols_set
 *         for -CK); otherwise valid == false and errorMessage describes why.
 */
RotateCommandResult parseRotateCommand(const std::string& command, const std::string &user_name, const std::string &db_name) {
    RotateCommandResult result;
    result.valid = false;
    DekInterface::setInfoUser(user_name);
    DekInterface::setInfoDb(db_name);

    // Compiled once: building a std::regex is expensive. icase makes the keywords
    // case-insensitive while the captured names keep their original case.
    static const std::regex tkPattern(R"(^\s*ROTATE\s+DEK\s+NOW\s+`?(\w+)`?\s+-TK\s*;?\s*$)", std::regex::icase);
    static const std::regex ckPattern(R"(^\s*ROTATE\s+DEK\s+NOW\s+`?(\w+)`?\s+-CK\s+([\w\s,]+)\s*;?\s*$)", std::regex::icase);
    static const std::regex allPattern(R"(^\s*ROTATE\s+DEK\s+NOW\s+`?(\w+)`?\s+-ALL\s*;?\s*$)", std::regex::icase);
    std::smatch matches;

    // -TK: rotate the table-level key only.
    if (std::regex_search(command, matches, tkPattern)) {
        result.valid = true;
        result.type = ROTATE_TABLE;
        result.tableName = matches[1].str();
        return result;
    }

    // -CK: rotate the keys of the listed columns.
    if (std::regex_search(command, matches, ckPattern)) {
        result.valid = true;
        result.type = ROTATE_COLUMNS;
        result.tableName = matches[1].str();

        // Drop every whitespace character the regex \s may have let through.
        const std::string columnsStr = matches[2].str();
        std::string cleanColumnsStr;
        cleanColumnsStr.reserve(columnsStr.size());
        for (char c : columnsStr) {
            if (c != ' ' && c != '\t' && c != '\n' && c != '\r' && c != '\v' && c != '\f') {
                cleanColumnsStr += c;
            }
        }

        // Split on commas in one pass; empty tokens ("a,,b", trailing ',') are skipped.
        std::string::size_type start = 0;
        while (start <= cleanColumnsStr.size()) {
            std::string::size_type comma = cleanColumnsStr.find(',', start);
            if (comma == std::string::npos) {
                comma = cleanColumnsStr.size();
            }
            if (comma > start) {
                result.cols_set.insert(cleanColumnsStr.substr(start, comma - start));
            }
            start = comma + 1;
        }

        // At least one column is required.
        if (result.cols_set.empty()) {
            result.valid = false;
            result.errorMessage = "列密钥轮换需要至少指定一个列名";
        }
        return result;
    }

    // -ALL: rotate the table key and every column key.
    if (std::regex_search(command, matches, allPattern)) {
        result.valid = true;
        result.type = ROTATE_ALL;
        result.tableName = matches[1].str();
        return result;
    }

    // No pattern matched.
    result.errorMessage = "无效的DEK轮换命令格式";
    return result;
}
/**
 * Example helper: print a parsed rotate command to stdout.
 *
 * Invalid results print their errorMessage. Valid ones print the table name,
 * the rotation type and, for ROTATE_COLUMNS, the selected columns as a
 * ", "-separated list (previously the separator was emitted after every element
 * except the first, producing e.g. "ab, c, ").
 */
void printRotateCommandResult(const RotateCommandResult& result) {
    if (!result.valid) {
        std::cout << "无效命令: " << result.errorMessage << std::endl;
        return;
    }
    std::cout << "检测到有效命令!" << std::endl;
    std::cout << "表名: " << result.tableName << std::endl;
    switch (result.type) {
    case ROTATE_ALL:
        std::cout << "轮换类型: ALL (表和所有列)" << std::endl;
        break;
    case ROTATE_TABLE:
        std::cout << "轮换类型: TABLE (仅表密钥)" << std::endl;
        break;
    case ROTATE_COLUMNS: {
        std::cout << "轮换类型: COLUMNS (仅指定列)" << std::endl;
        std::cout << "要轮换的列: ";
        // The separator goes *before* every element except the first.
        const char *separator = "";
        for (const auto &col : result.cols_set) {
            std::cout << separator << col;
            separator = ", ";
        }
        std::cout << std::endl;
        break;
    }
    }
}
/**
 * Delete every dek_store row of the current context, i.e. rows whose
 * username / db / t equal DekInterface::getInfoUser() / getInfoDb() /
 * getInfoTable(). The generated statement is echoed to stdout; failures are
 * reported on stderr only (nothing is returned).
 */
void connectionDelete() {
    // Key-store database connection.
    const char *conninfo = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
    PGconn *conn = PQconnectdb(conninfo);
    if (PQstatus(conn) != CONNECTION_OK) {
        fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
        PQfinish(conn);
        return;
    }

    // Render a value as an SQL string literal so a name containing ' cannot end
    // the literal and inject SQL into this DELETE. Quotes are doubled. If the value
    // contains a backslash, the E'' form is used with backslashes doubled, which
    // is correct whether or not standard_conforming_strings is on. Values without
    // quotes/backslashes render exactly as before.
    auto sqlLiteral = [](const std::string &value) {
        const bool has_backslash = value.find('\\') != std::string::npos;
        std::string literal = has_backslash ? "E'" : "'";
        for (char c : value) {
            if (c == '\'' || (has_backslash && c == '\\')) {
                literal += c;
            }
            literal += c;
        }
        literal += '\'';
        return literal;
    };

    std::string query = "DELETE FROM dek_store WHERE ";
    query.append("username = ").append(sqlLiteral(DekInterface::getInfoUser())).append(" AND ")
         .append("db = ").append(sqlLiteral(DekInterface::getInfoDb())).append(" AND ")
         .append("t = ").append(sqlLiteral(DekInterface::getInfoTable()));
    query.append(";");
    std::cout << query << std::endl;

    PGresult *res = PQexec(conn, query.c_str());
    if (PQresultStatus(res) != PGRES_COMMAND_OK) {
        fprintf(stderr, "Delete query execution failed: %s\n", PQerrorMessage(conn));
    }
    PQclear(res);
    PQfinish(conn);
}
/**
 * Run a DEK rotation as three consecutive stages: connectionUpdateDek_Init,
 * connectionUpdateDek_Update and connectionUpdateDek_Final.
 *
 * After each stage, a non-empty cmd.errorMessage is printed to stdout and the
 * remaining stages are skipped.
 */
void connectionUpdateDek(RotateCommandResult &cmd) {
    using Stage = void (*)(RotateCommandResult &);
    const Stage stages[] = {
        connectionUpdateDek_Init,    // work out columns, stage keys
        connectionUpdateDek_Update,  // rewrite the data
        connectionUpdateDek_Final,   // persist the new keys
    };
    for (Stage stage : stages) {
        stage(cmd);
        if (cmd.errorMessage.empty()) {
            continue;
        }
        std::cout << cmd.errorMessage << std::endl;
        return;
    }
}
/**
 * Stage 1 of DEK rotation: work out which columns must be re-encrypted and
 * stage the keys the rewrite should use.
 *
 * - Stores cmd.tableName via DekInterface::setInfoTable.
 * - Reads /etc/encryptsql/map.json for the table's columns, the column-name
 *   mapping (cmd.col_map) and, for each encrypted column, its type
 *   (cmd.col_type). A column that maps to itself is plaintext and is dropped.
 * - Loads every dek_store row whose t is the mapped table name
 *   (cmd.enc_tableName), decrypts the stored key in place with
 *   KMSInterface::decryptData and creates a fresh key per row with
 *   KMSInterface::createDek (also passed through decryptData).
 * - Stages the new or the current key for each row:
 *     ROTATE_ALL     new table key, new key for every column key;
 *                    cmd.cols_set = all encrypted columns.
 *     ROTATE_TABLE   new table key, current column keys;
 *                    cmd.cols_set = encrypted columns without a column key of
 *                    their own (those protected by the table key).
 *     ROTATE_COLUMNS current table key, new keys only for the listed columns;
 *                    cmd.cols_set stays as parsed.
 *   Encrypted names of columns receiving a new key are added to cmd.enc_cols_set.
 *
 * On failure cmd.errorMessage is set. That includes a required new key that
 * createDek could not produce: continuing would silently skip that key or hand
 * an empty one to the later stages, so the rotation stops before data is touched.
 */
void connectionUpdateDek_Init(RotateCommandResult &cmd) {
    DekInterface::setInfoTable(cmd.tableName);
    // Key-store database connection.
    const char *conninfo = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
    PGconn *conn = PQconnectdb(conninfo);
    if (PQstatus(conn) != CONNECTION_OK) {
        cmd.errorMessage = "Connection to database failed;\n";
        fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
        PQfinish(conn);
        return;
    }

    // Table layout from the mapping file (read_json_from_file exits the process
    // when the file cannot be opened).
    const std::string file_path = "/etc/encryptsql/map.json";
    json j = read_json_from_file(file_path);
    std::set<std::string> all_columns = get_all_columns(cmd.tableName, j);
    cmd.col_map = get_column_mapping(cmd.tableName, j);

    // Keep only encrypted columns and record their types. A column mapping to
    // itself is plaintext; a column absent from the map is kept. find() avoids
    // inserting empty entries into col_map. The label describes what is
    // printed: the remaining, encrypted columns.
    std::cout << "" + cmd.tableName + " 的密文列有: ";
    for (auto it = all_columns.begin(); it != all_columns.end();) {
        const std::string &col = *it;
        const auto mapped = cmd.col_map.find(col);
        if (mapped != cmd.col_map.end() && mapped->second == col) {
            it = all_columns.erase(it);
            continue;
        }
        cmd.col_type[col] = get_column_type(cmd.tableName, col, j);
        std::cout << col << ":" << cmd.col_type[col] << " ";
        ++it;
    }
    std::cout << std::endl;

    cmd.enc_tableName = getMappedName(T_STRING_TABLE, cmd.tableName.c_str(), NULL);

    const std::string query1 = "SELECT c, dek FROM dek_store WHERE t = '" + cmd.enc_tableName + "';";
    PGresult *res1 = PQexec(conn, query1.c_str());
    if (PQresultStatus(res1) != PGRES_TUPLES_OK) {
        cmd.errorMessage = "Query execution failed: " + query1 + "\n";
        fprintf(stderr, "Query execution failed: %s\n", PQerrorMessage(conn));
        PQclear(res1);
        PQfinish(conn);
        return;
    }

    // Abort because a required new key is missing; releases result and connection.
    auto failNoNewDek = [&](const std::string &what) {
        cmd.errorMessage = "rotate dek failed: can't create new dek for " + what + "\n";
        PQclear(res1);
        PQfinish(conn);
    };

    const int nrows = PQntuples(res1);
    for (int i = 0; i < nrows; i++) {
        // c IS NULL marks the table-level key; otherwise c is the encrypted column name.
        const bool is_table_level = PQgetisnull(res1, i, 0);
        std::string col_name(PQgetvalue(res1, i, 0));
        std::string dek(PQgetvalue(res1, i, 1));
        KMSInterface::decryptData(dek);  // current key, decrypted in place

        // A fresh key is created for every row, even when it ends up unused.
        std::string dek_for_update;
        bool created = false;
        if (is_table_level) {
            created = KMSInterface::createDek(dek_for_update, "");
        } else {
            created = KMSInterface::createDek(dek_for_update, col_name);
        }
        if (created) {
            KMSInterface::decryptData(dek_for_update);
        } else {
            fprintf(stderr, "rotate dek failed: can't create new dek.\n");
        }

        if (is_table_level) {
            const bool rotate_table_key = (cmd.type == ROTATE_TABLE || cmd.type == ROTATE_ALL);
            if (rotate_table_key && !created) {
                failNoNewDek("table " + cmd.tableName);
                return;
            }
            DekInterface::setDekTableLevelForUpdate(rotate_table_key ? dek_for_update : dek);
            continue;
        }

        // Plain column name: the mapped name with its last "_suffix" stripped
        // (e.g. col1_AES -> col1); empty when the encrypted name is unmapped.
        const auto mapped = cmd.col_map.find(col_name);
        std::string plain = (mapped == cmd.col_map.end()) ? std::string() : mapped->second;
        plain = plain.substr(0, plain.rfind('_'));

        bool rotate_column_key = false;
        if (cmd.type == ROTATE_TABLE) {
            // Has its own key, so a table-key rotation need not re-encrypt it.
            all_columns.erase(plain);
        } else if (cmd.type == ROTATE_ALL) {
            rotate_column_key = true;
        } else if (cmd.type == ROTATE_COLUMNS) {
            rotate_column_key = cmd.cols_set.find(plain) != cmd.cols_set.end();
        }

        if (rotate_column_key) {
            if (!created) {
                failNoNewDek("column " + col_name);
                return;
            }
            cmd.enc_cols_set.insert(col_name);  // encrypted column name
            DekInterface::setDekColLevelForUpdate(col_name, dek_for_update);
        } else {
            DekInterface::setDekColLevelForUpdate(col_name, dek);
        }
    }
    PQclear(res1);  // previously leaked on the success path

    if (cmd.type == ROTATE_ALL || cmd.type == ROTATE_TABLE) {
        cmd.cols_set = all_columns;
    }
    // ROTATE_COLUMNS keeps the user's list as parsed.
    // TODO: report an error when a listed column is not an encrypted column of this table.

    std::cout << "需要轮换的列: ";
    for (const auto &col : cmd.cols_set) {
        std::cout << col << " ";
    }
    std::cout << std::endl;
    PQfinish(conn);
}
/**
 * Stage 2 of DEK rotation: rewrite the data of cmd.tableName so it gets
 * re-encrypted.
 *
 * Connects to dbname=postgres, runs the "enc on;" command, then walks the table
 * in rotate_id order, BATCH_SIZE rows at a time. For each row it issues one
 * UPDATE that assigns every column in cmd.cols_set its current value. Per the
 * original design note, the UPDATE path prefers the pending ("ForUpdate") keys
 * staged by connectionUpdateDek_Init, so no explicit key switch happens here.
 *
 * All batches run in ONE transaction that is committed after the last row.
 * (Previously COMMIT ran inside the loop, so later batches ran outside any
 * transaction and could not be rolled back.) On any failure the transaction is
 * rolled back and cmd.errorMessage is set, so connectionUpdateDek stops before
 * connectionUpdateDek_Final stores new keys for data that was not rewritten.
 *
 * Assumptions:
 * - rotate_id is an integer column usable as a cursor.
 * - Values of columns whose cmd.col_type is "text" are written as string
 *   literals with single quotes doubled, which relies on the server default
 *   standard_conforming_strings = on.
 * - Other values are written verbatim.
 */
void connectionUpdateDek_Update(RotateCommandResult &cmd) {
    // Nothing to re-encrypt, e.g. ROTATE_TABLE when every encrypted column has
    // its own key. Skip the data pass without an error so the key update still
    // runs. Previously this built "SELECT rotate_id,  FROM ..." and failed.
    if (cmd.cols_set.empty()) {
        std::cout << "没有需要重新加密的列,跳过数据更新" << std::endl;
        return;
    }

    // Application database connection.
    const char *conninfo = "dbname=postgres user=postgres hostaddr=127.0.0.1 port=5432";
    PGconn *conn = PQconnectdb(conninfo);
    if (PQstatus(conn) != CONNECTION_OK) {
        cmd.errorMessage = "Connection to database failed;\n";
        fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
        PQfinish(conn);
        return;
    }

    // Record the error, roll back and close. Print PQerrorMessage(conn) before
    // calling this: the ROLLBACK replaces it.
    auto fail = [&cmd, conn](const std::string &message) {
        cmd.errorMessage = message;
        PQclear(PQexec(conn, "ROLLBACK"));
        PQfinish(conn);
    };

    // Render a text value as an SQL string literal (single quotes doubled), so
    // values containing ' neither break nor alter the UPDATE.
    auto textLiteral = [](const char *value) {
        std::string literal = "'";
        for (const char *p = value; *p != '\0'; ++p) {
            if (*p == '\'') {
                literal += '\'';
            }
            literal += *p;
        }
        literal += '\'';
        return literal;
    };

    const std::string &table_name = cmd.tableName;

    PGresult *res_begin = PQexec(conn, "BEGIN");
    if (PQresultStatus(res_begin) != PGRES_COMMAND_OK) {
        fprintf(stderr, "BEGIN transaction failed: %s\n", PQerrorMessage(conn));
        PQclear(res_begin);
        fail("BEGIN transaction failed\n");
        return;
    }
    PQclear(res_begin);

    // 1. Enable the encryption layer for this session.
    const std::string enc_on = "enc on;";
    PGresult *res_enc = PQexec(conn, enc_on.c_str());
    if (PQresultStatus(res_enc) != PGRES_TUPLES_OK) {
        fprintf(stderr, "Enable encryption failed: %s\n", PQerrorMessage(conn));
        PQclear(res_enc);
        fail("Enable encryption failed: " + enc_on + "\n");
        return;
    }
    std::cout << enc_on << std::endl;
    PQclear(res_enc);

    // 2. Row count and highest rotate_id.
    const std::string count_query = "SELECT COUNT(*), MAX(rotate_id) FROM " + table_name + ";";
    PGresult *res_count = PQexec(conn, count_query.c_str());
    if (PQresultStatus(res_count) != PGRES_TUPLES_OK) {
        fprintf(stderr, "Count query failed: %s\n", PQerrorMessage(conn));
        PQclear(res_count);
        fail("Query execution failed: " + count_query + "\n");
        return;
    }
    const int total_rows = atoi(PQgetvalue(res_count, 0, 0));
    const int max_id = atoi(PQgetvalue(res_count, 0, 1));
    PQclear(res_count);
    std::cout << "" << table_name << " 共有 " << total_rows << " 行数据, 最大ID为 " << max_id << std::endl;

    if (total_rows == 0) {
        std::cout << "表为空,无需更新" << std::endl;
        PQclear(PQexec(conn, "ROLLBACK"));
        PQfinish(conn);
        return;
    }

    // Select list in cmd.cols_set order; the UPDATE below iterates the same set,
    // so result field k + 1 holds the k-th column (field 0 is rotate_id).
    std::string columns_str;
    bool first = true;
    for (const auto &col : cmd.cols_set) {
        if (!first) {
            columns_str += ", ";
        }
        columns_str += col;
        first = false;
    }

    const int BATCH_SIZE = 1000;
    int current_id = 0;
    int processed_rows = 0;
    while (processed_rows < total_rows) {
        const std::string select_query = "SELECT rotate_id, " + columns_str +
                                         " FROM " + table_name +
                                         " WHERE rotate_id > " + std::to_string(current_id) +
                                         " ORDER BY rotate_id LIMIT " + std::to_string(BATCH_SIZE) + ";";
        std::cout << "执行查询: " << select_query << std::endl;
        PGresult *res_select = PQexec(conn, select_query.c_str());
        if (PQresultStatus(res_select) != PGRES_TUPLES_OK) {
            fprintf(stderr, "SELECT query failed: %s\n", PQerrorMessage(conn));
            PQclear(res_select);
            fail("Query execution failed: " + select_query + "\n");
            return;
        }
        const int batch_rows = PQntuples(res_select);
        if (batch_rows == 0) {
            PQclear(res_select);
            break;
        }
        std::cout << "本批次处理 " << batch_rows << " 行数据" << std::endl;

        // Rewrite each row with its own values to trigger re-encryption.
        for (int i = 0; i < batch_rows; ++i) {
            const int row_id = atoi(PQgetvalue(res_select, i, 0));
            current_id = row_id;
            std::string update_query = "UPDATE " + table_name + " SET ";
            int field = 1;
            for (const auto &colname : cmd.cols_set) {
                if (field > 1) {
                    update_query += ", ";
                }
                if (PQgetisnull(res_select, i, field)) {
                    update_query += colname + " = NULL";
                } else if (cmd.col_type[colname] == "text") {
                    update_query += colname + " = " + textLiteral(PQgetvalue(res_select, i, field));
                } else {
                    // Non-text values are emitted unquoted (e.g. numerics).
                    update_query += colname + " = " + PQgetvalue(res_select, i, field);
                }
                ++field;
            }
            update_query += " WHERE rotate_id = " + std::to_string(row_id) + ";";

            PGresult *res_update = PQexec(conn, update_query.c_str());
            if (PQresultStatus(res_update) != PGRES_COMMAND_OK) {
                fprintf(stderr, "UPDATE query failed for row %d: %s\n", row_id, PQerrorMessage(conn));
                PQclear(res_update);
                PQclear(res_select);
                fail("UPDATE query failed for row " + std::to_string(row_id) + "\n");
                return;
            }
            PQclear(res_update);
        }
        processed_rows += batch_rows;
        std::cout << "已处理 " << processed_rows << " / " << total_rows << " 行数据" << std::endl;
        PQclear(res_select);
    }

    // 3. Commit once, after every batch has been rewritten.
    PGresult *res_commit = PQexec(conn, "COMMIT");
    if (PQresultStatus(res_commit) != PGRES_COMMAND_OK) {
        fprintf(stderr, "COMMIT transaction failed: %s\n", PQerrorMessage(conn));
        PQclear(res_commit);
        cmd.errorMessage = "COMMIT transaction failed\n";
        PQfinish(conn);
        return;
    }
    PQclear(res_commit);
    PQfinish(conn);
}
/**
 * Stage 3 of DEK rotation: persist the pending keys in dek_store.
 *
 * All writes happen in one transaction on the dekmaster database:
 * - Unless cmd.type is ROTATE_COLUMNS, the pending table-level key
 *   (getDekTableLevelForUpdate) is encrypted with KMSInterface::encryptData and
 *   written to the row with t = cmd.enc_tableName and c IS NULL.
 * - Unless cmd.type is ROTATE_TABLE, and only when cmd.cols_set is non-empty,
 *   the pending key of every column in cmd.enc_cols_set is encrypted and
 *   written to its (t, c) row.
 *
 * A column without a pending key is skipped with a warning. Every other failure
 * rolls back the whole transaction and sets cmd.errorMessage. A failed statement
 * aborts a PostgreSQL transaction, so the old "continue" made the final COMMIT
 * silently roll back while the success message was still printed.
 */
void connectionUpdateDek_Final(RotateCommandResult &cmd) {
    // Key-store database connection.
    const char *conninfo = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
    PGconn *conn = PQconnectdb(conninfo);
    if (PQstatus(conn) != CONNECTION_OK) {
        cmd.errorMessage = "Connection to database failed;\n";
        fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
        PQfinish(conn);
        return;
    }

    // Record the error, roll back and close. Print PQerrorMessage(conn) before
    // calling this: the ROLLBACK replaces it.
    auto abortRotation = [&cmd, conn](const std::string &message) {
        cmd.errorMessage = message;
        PQclear(PQexec(conn, "ROLLBACK"));
        PQfinish(conn);
    };

    PGresult *res_begin = PQexec(conn, "BEGIN");
    if (PQresultStatus(res_begin) != PGRES_COMMAND_OK) {
        fprintf(stderr, "BEGIN transaction failed: %s\n", PQerrorMessage(conn));
        PQclear(res_begin);
        cmd.errorMessage = "BEGIN transaction failed\n";
        PQfinish(conn);
        return;
    }
    PQclear(res_begin);

    // Table-level key.
    std::string table_dek_for_update;
    DekInterface::getDekTableLevelForUpdate(table_dek_for_update);
    if (cmd.type != ROTATE_COLUMNS && !table_dek_for_update.empty()) {
        KMSInterface::encryptData(table_dek_for_update);
        if (table_dek_for_update.empty()) {
            fprintf(stderr, "Encrypt table DEK failed\n");
            abortRotation("Encrypt table DEK failed\n");
            return;
        }
        const std::string update_table_dek = "UPDATE dek_store SET dek = '" + table_dek_for_update +
                                             "' WHERE t = '" + cmd.enc_tableName + "' AND c IS NULL;";
        PGresult *res_update_table = PQexec(conn, update_table_dek.c_str());
        if (PQresultStatus(res_update_table) != PGRES_COMMAND_OK) {
            fprintf(stderr, "Update table DEK failed: %s\n", PQerrorMessage(conn));
            PQclear(res_update_table);
            abortRotation("Update table DEK failed\n");
            return;
        }
        PQclear(res_update_table);
        std::cout << "表级密钥更新成功" << std::endl;
    }

    // Column-level keys.
    if (!cmd.cols_set.empty() && cmd.type != ROTATE_TABLE) {
        std::unordered_map<std::string, std::string> column_deks;
        DekInterface::getAllDekColLevelForUpdate(column_deks);
        for (const auto &enc_col_name : cmd.enc_cols_set) {
            const auto pending = column_deks.find(enc_col_name);
            if (pending == column_deks.end() || pending->second.empty()) {
                fprintf(stderr, "No pending DEK for column %s, skipped\n", enc_col_name.c_str());
                continue;
            }
            std::string dek = pending->second;
            KMSInterface::encryptData(dek);
            if (dek.empty()) {
                fprintf(stderr, "Encrypt column DEK failed for column %s\n", enc_col_name.c_str());
                abortRotation("Encrypt column DEK failed for column " + enc_col_name + "\n");
                return;
            }
            const std::string update_col_dek = "UPDATE dek_store SET dek = '" + dek +
                                               "' WHERE t = '" + cmd.enc_tableName +
                                               "' AND c = '" + enc_col_name + "';";
            PGresult *res_update_col = PQexec(conn, update_col_dek.c_str());
            if (PQresultStatus(res_update_col) != PGRES_COMMAND_OK) {
                fprintf(stderr, "Update column DEK failed for column %s: %s\n",
                        enc_col_name.c_str(), PQerrorMessage(conn));
                PQclear(res_update_col);
                abortRotation("Update column DEK failed for column " + enc_col_name + "\n");
                return;
            }
            PQclear(res_update_col);
            std::cout << "" << enc_col_name << " 密钥更新成功" << std::endl;
        }
    }

    PGresult *res_commit = PQexec(conn, "COMMIT");
    if (PQresultStatus(res_commit) != PGRES_COMMAND_OK) {
        fprintf(stderr, "COMMIT transaction failed: %s\n", PQerrorMessage(conn));
        PQclear(res_commit);
        abortRotation("COMMIT transaction failed\n");
        return;
    }
    PQclear(res_commit);
    std::cout << "密钥轮换完成" << std::endl;
    PQfinish(conn);
}
/**
 * Create and store DEKs for the current (user, db, table) context.
 *
 * Creates one table-level key (column NULL) plus one key per column returned by
 * successive DekInterface::getInfoCol() calls; the iteration ends when
 * getInfoCol() returns an empty string. All keys are inserted into dek_store
 * with a single multi-row INSERT, which is echoed to stdout. Keys are stored as
 * returned by KMSInterface::createDek.
 *
 * A column whose key cannot be created is reported and skipped. Previously the
 * loop did not advance in that case and spun forever.
 */
void connectionInsertTest() {
    // Key-store database connection.
    const char *conninfo1 = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
    PGconn *conn1 = PQconnectdb(conninfo1);
    if (PQstatus(conn1) != CONNECTION_OK) {
        fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn1));
        PQfinish(conn1);
        return;  // previously fell through and used the freed connection
    }

    std::string table_dek;
    if (!KMSInterface::createDek(table_dek, "")) {  // table-level key uses an empty column name
        fprintf(stderr, "create table dek failed\n");
        PQfinish(conn1);
        return;
    }

    // One VALUES tuple: (username, db, t, c, dek). col_literal is already SQL
    // (NULL or a quoted name).
    auto makeTuple = [](const std::string &col_literal, const std::string &dek) {
        std::string tuple = "(";
        tuple.append("'").append(DekInterface::getInfoUser()).append("',")
             .append("'").append(DekInterface::getInfoDb()).append("',")
             .append("'").append(DekInterface::getInfoTable()).append("',")
             .append(col_literal).append(",")
             .append("'").append(dek).append("')");
        return tuple;
    };

    std::string query1 = "insert into dek_store values";
    query1.append(makeTuple("NULL", table_dek));

    // Tuples are joined with ',' only when actually appended, so a skipped
    // column can no longer leave a dangling comma.
    for (std::string col = DekInterface::getInfoCol(); !col.empty(); col = DekInterface::getInfoCol()) {
        std::string col_dek;
        if (!KMSInterface::createDek(col_dek, col)) {
            fprintf(stderr, "create dek failed for column %s, skipped\n", col.c_str());
            continue;
        }
        query1.append(",").append(makeTuple("'" + col + "'", col_dek));
    }
    query1.append(";");
    std::cout << query1 << std::endl;

    PGresult *res1 = PQexec(conn1, query1.c_str());
    if (PQresultStatus(res1) != PGRES_COMMAND_OK) {
        fprintf(stderr, "Query execution failed: %s\n", PQerrorMessage(conn1));
    }
    PQclear(res1);
    PQfinish(conn1);
}
/**
 * Debug helper: print every dek_store row whose t equals
 * DekInterface::getInfoTable(), one row per line with tab-separated fields.
 *
 * Returns right after reporting a connection or query failure. Previously both
 * paths fell through and used the already-freed connection/result.
 */
void connectionSelectTest() {
    const char *conninfo1 = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
    PGconn *conn1 = PQconnectdb(conninfo1);
    if (PQstatus(conn1) != CONNECTION_OK) {
        fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn1));
        PQfinish(conn1);
        return;
    }

    std::string query1 = "SELECT * FROM dek_store WHERE t = \'" + DekInterface::getInfoTable() + "\';";
    PGresult *res1 = PQexec(conn1, query1.c_str());
    if (PQresultStatus(res1) != PGRES_TUPLES_OK) {
        fprintf(stderr, "Query execution failed: %s\n", PQerrorMessage(conn1));
        PQclear(res1);
        PQfinish(conn1);
        return;
    }

    const int nrows = PQntuples(res1);
    const int ncols = PQnfields(res1);
    for (int i = 0; i < nrows; i++) {
        for (int j = 0; j < ncols; j++) {
            std::cout << PQgetvalue(res1, i, j) << "\t";
        }
        std::cout << std::endl;
    }
    PQclear(res1);
    PQfinish(conn1);
}
void connectionSelect() {
const char *conninfo1 = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
PGconn *conn1 = PQconnectdb(conninfo1);
if (PQstatus(conn1) != CONNECTION_OK) {
fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn1));
PQfinish(conn1);
return;
}
// 执行查询
std::string table_name = DekInterface::getInfoTable();
std::string query1 = "SELECT c, dek FROM dek_store WHERE t = '" + table_name + "';";
PGresult *res1 = PQexec(conn1, query1.c_str());
// 检查查询状态
if (PQresultStatus(res1) != PGRES_TUPLES_OK) {
fprintf(stderr, "Query execution failed: %s\n", PQerrorMessage(conn1));
PQclear(res1);
PQfinish(conn1);
return;
}
// 获取查询结果的行数
int nrows = PQntuples(res1);
// 获取CMK
for (int i = 0; i < nrows; i++) { // 遍历每一行
std::string col_name(PQgetvalue(res1, i, 0)); // 列名
std::string dek(PQgetvalue(res1, i, 1)); // 对应的密钥
// 解密密钥
KMSInterface::decryptData(dek);
if (PQgetisnull(res1, i, 0)) { // 如果列名是NULL表示这是表级密钥
DekInterface::setDekTableLevel(dek);
} else { // 否则是列级密钥
DekInterface::setDekColLevel(col_name, dek);
}
}
PQclear(res1); // 释放查询结果
PQfinish(conn1); // 关闭数据库连接
}
// Read and parse a JSON document from file_path.
// Terminates the process with exit code 1 when the file cannot be opened;
// parse errors raised by `>>` are not handled here.
json read_json_from_file(const std::string& file_path) {
    std::ifstream input(file_path);
    if (input.is_open()) {
        json parsed;
        input >> parsed;
        return parsed;
    }
    std::cerr << "无法打开文件:" << file_path << std::endl;
    exit(1);
}
// 获取表中所有列
std::vector<std::string> get_columns(const std::string& table_name, const json& j) {
std::vector<std::string> columns;
try {
// 获取指定表名的列信息
auto columns_info = j["schema"]["public"][table_name]["columns"];
for (const auto& col : columns_info) {
columns.push_back(col);
}
} catch (const std::exception& e) {
std::cerr << "获取列时发生错误:" << e.what() << std::endl;
}
return columns;
}
// 获取表中所有列
std::set<std::string> get_all_columns(const std::string& table_name, const json& j) {
std::set<std::string> cols;
try {
// 定位到 schema.public.{table_name}.columns
auto cols_array = j.at("schema").at("public").at(table_name).at("columns");
for (const auto& col : cols_array) {
cols.insert(col.get<std::string>());
}
} catch (const std::exception& e) {
std::cerr << "获取所有列名时发生错误:" << e.what() << std::endl;
}
return cols;
}
// 获取所有密文列名 ↔ 明文列名 之间的映射,存储在一个 std::map 中
std::unordered_map<std::string, std::string> get_column_mapping(const std::string& table_name, const json& j) {
std::unordered_map<std::string, std::string> mapping;
try {
// 定位到 schema.public.{table_name}.map
auto map_info = j.at("schema").at("public").at(table_name).at("map");
for (auto it = map_info.begin(); it != map_info.end(); ++it) {
// key: 列名1可能是加密后的列名或明文标识
// value: 列名2对应的另一端
mapping[it.key()] = it.value().get<std::string>();
}
} catch (const std::exception& e) {
std::cerr << "获取列映射时发生错误:" << e.what() << std::endl;
}
return mapping;
}
// 获取列的数据类型
std::string get_column_type(const std::string& table_name, const std::string& column_name, const json& j) {
try {
// 获取表的信息
auto table_info = j["schema"]["public"][table_name];
// 检查列是否存在
if (table_info.contains(column_name)) {
// 返回列的类型
if (table_info[column_name].contains("type")) {
return table_info[column_name]["type"];
}
}
} catch (const std::exception& e) {
std::cerr << "获取列类型时发生错误:" << e.what() << std::endl;
}
return "";
}
/**
 * Run one SQL string on a fresh DB_CONNINFO connection, then close it.
 * @param sql      statement text passed to PQexec
 * @param errorMsg short description of the operation, used in the log lines
 * @return 0 if the result status is PGRES_COMMAND_OK, -1 otherwise (including
 *         connection failure)
 */
int executeSQL(const char *sql, const char *errorMsg) {
    PGconn *conn = PQconnectdb(DB_CONNINFO);
    if (PQstatus(conn) != CONNECTION_OK) {
        fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
        PQfinish(conn);
        return -1;
    }

    PGresult *res = PQexec(conn, sql);
    const bool succeeded = (PQresultStatus(res) == PGRES_COMMAND_OK);
    if (!succeeded) {
        fprintf(stderr, "Database operation failed (%s): %s\n", errorMsg, PQerrorMessage(conn));
        fprintf(stderr, "SQL: %s\n", sql);
    } else {
        printf("Database operation succeeded: %s\n", errorMsg);
    }
    PQclear(res);
    PQfinish(conn);
    return succeeded ? 0 : -1;
}
/**
 * Check whether a table named tableName exists, by name only, in
 * information_schema.tables (any schema).
 * NOTE(review): tableName is interpolated without escaping into a 256-byte
 * buffer; the callers in this file pass compile-time table-name constants.
 * @return 1 if it exists, 0 if not, -1 on connection or query failure
 */
int tableExists(const char *tableName) {
    PGconn *conn = PQconnectdb(DB_CONNINFO);
    if (PQstatus(conn) != CONNECTION_OK) {
        fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
        PQfinish(conn);
        return -1;
    }

    char query[256];
    snprintf(query, sizeof(query),
             "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '%s')",
             tableName);
    PGresult *res = PQexec(conn, query);

    int verdict = -1;
    if (PQresultStatus(res) != PGRES_TUPLES_OK || PQntuples(res) <= 0) {
        fprintf(stderr, "Failed to check table existence: %s\n", PQerrorMessage(conn));
    } else {
        verdict = (strcmp(PQgetvalue(res, 0, 0), "t") == 0) ? 1 : 0;
    }
    PQclear(res);
    PQfinish(conn);
    return verdict;
}
/**
 * Back up the key table: drop any existing BACKUP_TABLE_NAME, then recreate it
 * as a copy of ORIGINAL_TABLE_NAME. Each step is sent via executeSQL. When the
 * original table does not exist, nothing is done and 0 is returned.
 * @return 0 on success (or nothing to back up), -1 on failure
 */
int backupDekStore() {
    const int present = tableExists(ORIGINAL_TABLE_NAME);
    if (present == 0) {
        printf("Original table %s does not exist, skipping database backup\n", ORIGINAL_TABLE_NAME);
        return 0;
    }
    if (present < 0) {
        fprintf(stderr, "Failed to check if original table exists\n");
        return -1;
    }

    char sql[512];
    // Remove a stale backup first.
    snprintf(sql, sizeof(sql), "DROP TABLE IF EXISTS %s", BACKUP_TABLE_NAME);
    if (executeSQL(sql, "drop existing backup table") != 0) {
        return -1;
    }
    // Copy structure and rows of the live table.
    snprintf(sql, sizeof(sql),
             "CREATE TABLE %s AS SELECT * FROM %s",
             BACKUP_TABLE_NAME, ORIGINAL_TABLE_NAME);
    if (executeSQL(sql, "create backup table") != 0) {
        return -1;
    }

    printf("Database table %s backed up successfully to %s\n", ORIGINAL_TABLE_NAME, BACKUP_TABLE_NAME);
    return 0;
}
/**
 * Drop the key-table backup (BACKUP_TABLE_NAME) if it exists.
 * @return 0 on success or when there is no backup, -1 on failure
 */
int deleteDekStoreBackup() {
    const int state = tableExists(BACKUP_TABLE_NAME);
    if (state == 0) {
        printf("Backup table %s does not exist\n", BACKUP_TABLE_NAME);
        return 0;
    }
    if (state < 0) {
        fprintf(stderr, "Failed to check if backup table exists\n");
        return -1;
    }

    char sql[256];
    snprintf(sql, sizeof(sql), "DROP TABLE %s", BACKUP_TABLE_NAME);
    if (executeSQL(sql, "delete backup table") != 0) {
        return -1;
    }
    printf("Database backup table %s deleted successfully\n", BACKUP_TABLE_NAME);
    return 0;
}
/**
 * Restore the key table from its backup: replace ORIGINAL_TABLE_NAME with
 * BACKUP_TABLE_NAME (the backup is renamed, so it no longer exists afterwards).
 * When there is no backup, nothing is done and 0 is returned.
 *
 * DROP and RENAME are sent in ONE PQexec call. PostgreSQL runs a multi-statement
 * simple query as a single implicit transaction, and its DDL is transactional,
 * so a failed RENAME also undoes the DROP and the current table survives.
 * Previously the two ran as separate executeSQL calls, each on its own connection.
 * @return 0 on success (or nothing to restore), -1 on failure
 */
int restoreDekStore() {
    const int backupExists = tableExists(BACKUP_TABLE_NAME);
    if (backupExists < 0) {
        fprintf(stderr, "Failed to check if backup table exists\n");
        return -1;
    }
    if (backupExists == 0) {
        printf("No backup table %s found for rollback\n", BACKUP_TABLE_NAME);
        return 0;
    }

    char sql[512];
    snprintf(sql, sizeof(sql),
             "DROP TABLE IF EXISTS %s; ALTER TABLE %s RENAME TO %s",
             ORIGINAL_TABLE_NAME, BACKUP_TABLE_NAME, ORIGINAL_TABLE_NAME);
    if (executeSQL(sql, "restore table from backup") != 0) {
        return -1;
    }
    printf("Database table %s restored from backup successfully\n", ORIGINAL_TABLE_NAME);
    return 0;
}
}