1000 lines
35 KiB
C++
Executable File
1000 lines
35 KiB
C++
Executable File
#include "kms/kms_interface.hpp"
|
||
#include "kmsAdapter/dek_interface.hpp"
|
||
#include <fstream> // 必须包含这个头文件
|
||
|
||
namespace DekInterface{
|
||
|
||
/**
 * Parse a DEK rotation command of the form
 *
 *   ROTATE DEK NOW <table> -TK              rotate the table-level key only
 *   ROTATE DEK NOW <table> -CK col1, col2   rotate the listed column keys
 *   ROTATE DEK NOW <table> -ALL             rotate the table key and all column keys
 *
 * Keywords are matched case-insensitively, the table name may be wrapped in
 * backticks and a trailing ';' is optional. Identifiers keep their original case.
 *
 * Side effect: user_name and db_name are recorded through
 * DekInterface::setInfoUser / setInfoDb before parsing, even when the command
 * turns out to be invalid.
 *
 * @param command   raw command text
 * @param user_name user issuing the command
 * @param db_name   database the command applies to
 * @return valid == true with type/tableName (and cols_set for -CK) filled in,
 *         or valid == false with errorMessage describing the problem.
 */
RotateCommandResult parseRotateCommand(const std::string& command, const std::string &user_name, const std::string &db_name) {
    RotateCommandResult result;
    result.valid = false;

    DekInterface::setInfoUser(user_name);
    DekInterface::setInfoDb(db_name);

    // The patterns never change, so compile them once (std::regex construction
    // is expensive). They are case-insensitive, which is why no upper-cased copy
    // of the command is needed (the old copy was never used, and calling toupper
    // on a negative char such as a UTF-8 byte is undefined behaviour).
    static const std::regex tkPattern(R"(^\s*ROTATE\s+DEK\s+NOW\s+`?(\w+)`?\s+-TK\s*;?\s*$)", std::regex::icase);
    static const std::regex ckPattern(R"(^\s*ROTATE\s+DEK\s+NOW\s+`?(\w+)`?\s+-CK\s+([\w\s,]+)\s*;?\s*$)", std::regex::icase);
    static const std::regex allPattern(R"(^\s*ROTATE\s+DEK\s+NOW\s+`?(\w+)`?\s+-ALL\s*;?\s*$)", std::regex::icase);

    std::smatch matches;

    // -TK: rotate only the table-level key.
    if (std::regex_search(command, matches, tkPattern)) {
        result.valid = true;
        result.type = ROTATE_TABLE;
        result.tableName = matches[1].str();
        return result;
    }

    // -CK: rotate the keys of the listed columns.
    if (std::regex_search(command, matches, ckPattern)) {
        result.valid = true;
        result.type = ROTATE_COLUMNS;
        result.tableName = matches[1].str();

        // Split the list on ',' while dropping spaces, tabs and line breaks;
        // empty entries (e.g. "a,,b") are ignored.
        const std::string columnsStr = matches[2].str();
        std::string column;
        for (char c : columnsStr) {
            if (c == ',') {
                if (!column.empty()) {
                    result.cols_set.insert(column);
                }
                column.clear();
            } else if (c != ' ' && c != '\t' && c != '\n' && c != '\r') {
                column += c;
            }
        }
        if (!column.empty()) {
            result.cols_set.insert(column);
        }

        // At least one column name is required.
        if (result.cols_set.empty()) {
            result.valid = false;
            result.errorMessage = "列密钥轮换需要至少指定一个列名";
        }
        return result;
    }

    // -ALL: rotate the table key and every column key.
    if (std::regex_search(command, matches, allPattern)) {
        result.valid = true;
        result.type = ROTATE_ALL;
        result.tableName = matches[1].str();
        return result;
    }

    // No pattern matched.
    result.errorMessage = "无效的DEK轮换命令格式";
    return result;
}
|
||
|
||
// 使用示例函数
|
||
void printRotateCommandResult(const RotateCommandResult& result) {
|
||
if (!result.valid) {
|
||
std::cout << "无效命令: " << result.errorMessage << std::endl;
|
||
return;
|
||
}
|
||
|
||
std::cout << "检测到有效命令!" << std::endl;
|
||
std::cout << "表名: " << result.tableName << std::endl;
|
||
|
||
switch (result.type) {
|
||
case ROTATE_ALL:
|
||
std::cout << "轮换类型: ALL (表和所有列)" << std::endl;
|
||
break;
|
||
case ROTATE_TABLE:
|
||
std::cout << "轮换类型: TABLE (仅表密钥)" << std::endl;
|
||
break;
|
||
case ROTATE_COLUMNS:
|
||
std::cout << "轮换类型: COLUMNS (仅指定列)" << std::endl;
|
||
std::cout << "要轮换的列: ";
|
||
for (auto it = result.cols_set.begin();it != result.cols_set.end(); ++ it) {
|
||
std::cout << *it;
|
||
if (it != result.cols_set.begin()) {
|
||
std::cout << ", ";
|
||
}
|
||
}
|
||
std::cout << std::endl;
|
||
break;
|
||
}
|
||
}
|
||
|
||
void connectionDelete() {
|
||
// 数据库连接信息
|
||
const char *conninfo = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
|
||
PGconn *conn = PQconnectdb(conninfo);
|
||
if (PQstatus(conn) != CONNECTION_OK) {
|
||
fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
|
||
// 构建删除查询
|
||
std::string query = "DELETE FROM dek_store WHERE ";
|
||
query.append("username = '").append(DekInterface::getInfoUser()).append("' AND ")
|
||
.append("db = '").append(DekInterface::getInfoDb()).append("' AND ")
|
||
.append("t = '").append(DekInterface::getInfoTable()).append("'");
|
||
|
||
query.append(";");
|
||
std::cout << query << std::endl;
|
||
const char *query_c = query.c_str();
|
||
|
||
PGresult *res = PQexec(conn, query_c);
|
||
// 检查查询状态
|
||
if (PQresultStatus(res) != PGRES_COMMAND_OK) {
|
||
fprintf(stderr, "Delete query execution failed: %s\n", PQerrorMessage(conn));
|
||
PQclear(res);
|
||
} else {
|
||
PQclear(res); // 释放查询结果
|
||
}
|
||
|
||
PQfinish(conn); // 关闭数据库连接
|
||
}
|
||
|
||
/**
 * Run a DEK rotation for `cmd` as three consecutive steps:
 *   1. connectionUpdateDek_Init   - prepare the keys and the affected columns,
 *   2. connectionUpdateDek_Update - rewrite the affected column data,
 *   3. connectionUpdateDek_Final  - persist the new keys.
 *
 * Each step reports failure through cmd.errorMessage. As soon as that is
 * non-empty after a step, the message is printed to stdout and the remaining
 * steps are skipped.
 */
void connectionUpdateDek(RotateCommandResult &cmd) {
    using RotationStep = void (*)(RotateCommandResult &);
    const RotationStep steps[] = {
        connectionUpdateDek_Init,
        connectionUpdateDek_Update,
        connectionUpdateDek_Final,
    };

    for (RotationStep step : steps) {
        step(cmd);
        if (!cmd.errorMessage.empty()) {
            std::cout << cmd.errorMessage << std::endl;
            return;
        }
    }
}
|
||
|
||
/**
|
||
* 连接数据库
|
||
* 从 map.json 中提取表的所有列为columns
|
||
* 从 dek_store 表中提取所有密钥,通过映射得到原列名,记录在 cmd.enc_cols_set 中。
|
||
*
|
||
* 根据 cmd.type 有三种处理
|
||
* 1. ROTATE_ALL: 轮换所有列,那么 cmd.cols_set = all_columns
|
||
* 2. ROTATE_TABLE: 轮换表级密钥,那么 cmd.cols_set = columns - cmd.enc_colsset
|
||
* 3. ROTATE_COLUMNS: 轮换指定列密钥,那么 cmd.cols_set = col1, col2, ...
|
||
*
|
||
* 后续通过cmd.cols 构造查询和更新
|
||
*/
|
||
void connectionUpdateDek_Init(RotateCommandResult &cmd) {
|
||
DekInterface::setInfoTable(cmd.tableName);
|
||
|
||
// 数据库连接信息
|
||
const char *conninfo = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
|
||
PGconn *conn = PQconnectdb(conninfo);
|
||
if (PQstatus(conn) != CONNECTION_OK) {
|
||
cmd.errorMessage = "Connection to database failed;\n";
|
||
fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
|
||
std::string file_path = "/etc/encryptsql/map.json";
|
||
json j = read_json_from_file(file_path);
|
||
|
||
std::set<std::string> all_columns = get_all_columns(cmd.tableName, j);
|
||
cmd.col_map = get_column_mapping(cmd.tableName,j);
|
||
|
||
// 移除非密文列
|
||
std::cout << "表 " + cmd.tableName + " 的非密文列有: " ;
|
||
for (auto it = all_columns.begin(); it != all_columns.end(); ) {
|
||
const auto& col = *it;
|
||
if (cmd.col_map[col] == col) {
|
||
// 移除非密文列
|
||
it = all_columns.erase(it); // erase 会返回下一个元素的迭代器
|
||
} else {
|
||
// 获得该列的类型
|
||
cmd.col_type[col] = get_column_type(cmd.tableName, col, j);
|
||
++it; // 只有在没有删除时才增加迭代器
|
||
std::cout << col << ":" << cmd.col_type[col] << " ";
|
||
}
|
||
}
|
||
std::cout << std::endl;
|
||
|
||
// 获取enc_table_name
|
||
cmd.enc_tableName = getMappedName(T_STRING_TABLE, cmd.tableName.c_str(), NULL);
|
||
std::string username = DekInterface::getInfoUser();
|
||
std::string db_name = DekInterface::getInfoDb();
|
||
|
||
// 执行查询
|
||
std::string query1 = "SELECT c, dek FROM dek_store WHERE t = '" + cmd.enc_tableName + "';";
|
||
PGresult *res1 = PQexec(conn, query1.c_str());
|
||
|
||
// 检查查询状态
|
||
if (PQresultStatus(res1) != PGRES_TUPLES_OK) {
|
||
cmd.errorMessage = "Query execution failed: " + query1 + "\n";
|
||
fprintf(stderr, "Query execution failed: %s\n", PQerrorMessage(conn));
|
||
PQclear(res1);
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
|
||
// 获取查询结果的行数
|
||
int nrows = PQntuples(res1);
|
||
|
||
// 获取CMK
|
||
// std::string cmk = DekAPI::getCurrentCmk();
|
||
|
||
// 遍历
|
||
for (int i = 0; i < nrows; i++) { // 遍历每一行
|
||
std::string col_name(PQgetvalue(res1, i, 0)); // 列名
|
||
std::string dek(PQgetvalue(res1, i, 1)); // 密钥
|
||
std::string dek_for_update; // 更新的密钥
|
||
|
||
// 解密密钥
|
||
// cmk_mapperDecryptDek(dek, cmk); // 解密密钥
|
||
KMSInterface::decryptData(dek);
|
||
|
||
// 判断 如果是 NULL 则为表级密钥 && (轮换表级||轮换全部) -> 需要轮换表密钥
|
||
if (PQgetisnull(res1, i, 0)) { // 如果列名是NULL,表示这是表级密钥
|
||
// 创建新密钥
|
||
if (KMSInterface::createDek(dek_for_update, "")) { // 表级密钥使用空列名, 也可以DekAPI::getInfoTable())
|
||
KMSInterface::decryptData(dek_for_update);
|
||
} else {
|
||
fprintf(stderr, "rotate dek failed: can't create new dek.\n");
|
||
}
|
||
|
||
if(cmd.type == ROTATE_TABLE || cmd.type == ROTATE_ALL) {
|
||
DekInterface::setDekTableLevelForUpdate(dek_for_update);
|
||
}else{
|
||
DekInterface::setDekTableLevelForUpdate(dek);
|
||
}
|
||
}else{
|
||
// 从密文列名col_name获取到明文列名如(col1_AES),然后作截断得到 col1
|
||
std::string plain = cmd.col_map[col_name];
|
||
plain = plain.substr(0,plain.rfind('_'));
|
||
|
||
// 创建新密钥,使用加密后的列名
|
||
if (KMSInterface::createDek(dek_for_update, col_name)) { // 使用加密后的列名
|
||
KMSInterface::decryptData(dek_for_update);
|
||
} else {
|
||
fprintf(stderr, "rotate dek failed: can't create new dek.\n");
|
||
}
|
||
|
||
if(cmd.type == ROTATE_TABLE){ // 轮换表级密钥
|
||
all_columns.erase(plain);
|
||
DekInterface::setDekColLevelForUpdate(col_name, dek);
|
||
} else if(cmd.type == ROTATE_ALL){ // 轮换所有密钥
|
||
cmd.enc_cols_set.insert(col_name); // 存储密文列名
|
||
DekInterface::setDekColLevelForUpdate(col_name, dek_for_update);
|
||
} else if(cmd.type == ROTATE_COLUMNS){ // 轮换指定列密钥
|
||
if(cmd.cols_set.find(plain) != cmd.cols_set.end()){ // 判断dek_store中的列是否在命令中,是的话用新密钥,不是则用旧密钥
|
||
cmd.enc_cols_set.insert(col_name); // 存储密文列名
|
||
DekInterface::setDekColLevelForUpdate(col_name, dek_for_update);
|
||
}else{
|
||
DekInterface::setDekColLevelForUpdate(col_name, dek);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// cmd.cols 命令中的列名(明文)cmd.enc_colsset 命令中的列名(密文)
|
||
if(cmd.type == ROTATE_ALL || cmd.type == ROTATE_TABLE){
|
||
cmd.cols_set = all_columns;
|
||
}else if(cmd.type == ROTATE_COLUMNS){
|
||
// skip
|
||
// todo 加一个 如果不是合理的密钥就报错的机制
|
||
}
|
||
|
||
std::cout << "需要轮换的列: ";
|
||
for(auto col: cmd.cols_set){
|
||
std::cout << col << " ";
|
||
}
|
||
std::cout << std::endl;
|
||
|
||
PQfinish(conn);
|
||
}
|
||
|
||
/**
|
||
* 连接数据库
|
||
* cmd.cols 构造查询
|
||
* cmd.cols 构造更新 (Update的逻辑会在有dekForUpdate的情况下选取dekForUpdate,所以这里无需切换)。
|
||
*/
|
||
void connectionUpdateDek_Update(RotateCommandResult &cmd) {
|
||
// 数据库连接信息
|
||
const char *conninfo = "dbname=postgres user=postgres hostaddr=127.0.0.1 port=5432";
|
||
PGconn *conn = PQconnectdb(conninfo);
|
||
if (PQstatus(conn) != CONNECTION_OK) {
|
||
cmd.errorMessage = "Connection to database failed;\n";
|
||
fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
|
||
std::string table_name = cmd.tableName;
|
||
|
||
// 开始事务
|
||
PQexec(conn, "BEGIN");
|
||
|
||
// 1. 开启密态功能
|
||
std::string enc_on = "enc on;";
|
||
PGresult *res_enc = PQexec(conn, enc_on.c_str());
|
||
if (PQresultStatus(res_enc) != PGRES_TUPLES_OK) {
|
||
cmd.errorMessage = "Enable encryption failed: " + enc_on + "\n";
|
||
fprintf(stderr, "Enable encryption failed: %s\n", PQerrorMessage(conn));
|
||
PQclear(res_enc);
|
||
PQexec(conn, "ROLLBACK");
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
std::cout << enc_on << std::endl;
|
||
PQclear(res_enc);
|
||
|
||
// 2. 获取表的总行数和最大rotate_id
|
||
std::string count_query = "SELECT COUNT(*), MAX(rotate_id) FROM " + table_name + ";";
|
||
PGresult *res_count = PQexec(conn, count_query.c_str());
|
||
|
||
if (PQresultStatus(res_count) != PGRES_TUPLES_OK) {
|
||
cmd.errorMessage = "Query execution failed: " + count_query + "\n";
|
||
fprintf(stderr, "Count query failed: %s\n", PQerrorMessage(conn));
|
||
PQclear(res_count);
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
|
||
int total_rows = atoi(PQgetvalue(res_count, 0, 0));
|
||
int max_id = atoi(PQgetvalue(res_count, 0, 1));
|
||
PQclear(res_count);
|
||
|
||
std::cout << "表 " << table_name << " 共有 " << total_rows << " 行数据, 最大ID为 " << max_id << std::endl;
|
||
|
||
// 如果表为空,直接返回
|
||
if (total_rows == 0) {
|
||
std::cout << "表为空,无需更新" << std::endl;
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
|
||
// 构建需要查询和更新的列字符串
|
||
std::string columns_str;
|
||
for (auto it = cmd.cols_set.begin(); it != cmd.cols_set.end(); ++it){
|
||
if (it != cmd.cols_set.begin()) {
|
||
columns_str += ", ";
|
||
}
|
||
columns_str += *it;
|
||
}
|
||
|
||
// 批处理大小
|
||
const int BATCH_SIZE = 1000;
|
||
int current_id = 0;
|
||
int processed_rows = 0;
|
||
|
||
// 分批处理数据
|
||
while (processed_rows < total_rows) {
|
||
// 查询一批数据
|
||
std::string select_query = "SELECT rotate_id, " + columns_str +
|
||
" FROM " + table_name +
|
||
" WHERE rotate_id > " + std::to_string(current_id) +
|
||
" ORDER BY rotate_id LIMIT " + std::to_string(BATCH_SIZE) + ";";
|
||
|
||
std::cout << "执行查询: " << select_query << std::endl;
|
||
|
||
PGresult *res_select = PQexec(conn, select_query.c_str());
|
||
if (PQresultStatus(res_select) != PGRES_TUPLES_OK) {
|
||
cmd.errorMessage = "Query execution failed: " + select_query + "\n";
|
||
fprintf(stderr, "SELECT query failed: %s\n", PQerrorMessage(conn));
|
||
PQclear(res_select);
|
||
PQexec(conn, "ROLLBACK");
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
|
||
int batch_rows = PQntuples(res_select);
|
||
if (batch_rows == 0) {
|
||
break;
|
||
}
|
||
|
||
std::cout << "本批次处理 " << batch_rows << " 行数据" << std::endl;
|
||
|
||
// 逐行更新数据
|
||
for (int i = 0; i < batch_rows; ++i) {
|
||
int row_id = atoi(PQgetvalue(res_select, i, 0));
|
||
current_id = row_id; // 更新当前处理的ID
|
||
|
||
// 构建更新语句 - 对每一列都使用原值更新,触发密钥轮换
|
||
std::string update_query = "UPDATE " + table_name + " SET ";
|
||
|
||
size_t j = 0;
|
||
for(auto it = cmd.cols_set.begin(); it != cmd.cols_set.end(); ++it,++ j) {
|
||
std::string col_value = PQgetvalue(res_select, i, j+1 ); //
|
||
std::string colname = *it;
|
||
|
||
// 检查是否为NULL
|
||
if (PQgetisnull(res_select, i, j+1 )) {
|
||
update_query += colname + " = NULL";
|
||
} else {
|
||
// 获取列的类型信息
|
||
std::string col_type = cmd.col_type[colname];
|
||
|
||
// 根据PostgreSQL类型OID判断列类型
|
||
bool is_text = (col_type == "text");
|
||
|
||
if (!is_text) {
|
||
// 数值类型不加引号
|
||
update_query += colname + " = " + col_value;
|
||
} else {
|
||
update_query += colname + " = '" + col_value + "'";
|
||
}
|
||
}
|
||
|
||
if (j < cmd.cols_set.size() - 1) {
|
||
update_query += ", ";
|
||
}
|
||
}
|
||
|
||
update_query += " WHERE rotate_id = " + std::to_string(row_id) + ";";
|
||
|
||
// 执行更新
|
||
PGresult *res_update = PQexec(conn, update_query.c_str());
|
||
if (PQresultStatus(res_update) != PGRES_COMMAND_OK) {
|
||
fprintf(stderr, "UPDATE query failed for row %d: %s\n", row_id, PQerrorMessage(conn));
|
||
PQclear(res_update);
|
||
PQclear(res_select);
|
||
PQexec(conn, "ROLLBACK");
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
PQclear(res_update);
|
||
}
|
||
|
||
processed_rows += batch_rows;
|
||
std::cout << "已处理 " << processed_rows << " / " << total_rows << " 行数据" << std::endl;
|
||
|
||
PQclear(res_select);
|
||
|
||
// 提交事务
|
||
PGresult *res_commit = PQexec(conn, "COMMIT");
|
||
if (PQresultStatus(res_commit) != PGRES_COMMAND_OK) {
|
||
fprintf(stderr, "COMMIT transaction failed: %s\n", PQerrorMessage(conn));
|
||
PQclear(res_commit);
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
PQclear(res_commit);
|
||
}
|
||
|
||
PQfinish(conn);
|
||
}
|
||
|
||
|
||
/**
|
||
* 连接数据库
|
||
* 开始事务
|
||
* 更新数据库中的表级和列级 DEK 记录
|
||
* 提交事务
|
||
* 清理临时资源
|
||
*/
|
||
void connectionUpdateDek_Final(RotateCommandResult &cmd) {
|
||
// 数据库连接信息
|
||
const char *conninfo = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
|
||
PGconn *conn = PQconnectdb(conninfo);
|
||
if (PQstatus(conn) != CONNECTION_OK) {
|
||
cmd.errorMessage = "Connection to database failed;\n";
|
||
fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
|
||
// 开始事务
|
||
PGresult *res_begin = PQexec(conn, "BEGIN");
|
||
if (PQresultStatus(res_begin) != PGRES_COMMAND_OK) {
|
||
fprintf(stderr, "BEGIN transaction failed: %s\n", PQerrorMessage(conn));
|
||
PQclear(res_begin);
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
PQclear(res_begin);
|
||
|
||
// 获取待更新的表级密钥
|
||
std::string table_dek_for_update;
|
||
DekInterface::getDekTableLevelForUpdate(table_dek_for_update);
|
||
|
||
// 获取CMK
|
||
|
||
// 如果有表级密钥需要更新
|
||
if (cmd.type != ROTATE_COLUMNS && !table_dek_for_update.empty()) {
|
||
// 加密表级密钥
|
||
KMSInterface::encryptData(table_dek_for_update);
|
||
if (table_dek_for_update.empty()) {
|
||
fprintf(stderr, "Encrypt table DEK failed\n");
|
||
PQexec(conn, "ROLLBACK");
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
|
||
// 更新表级密钥
|
||
std::string update_table_dek = "UPDATE dek_store SET dek = '" + table_dek_for_update +
|
||
"' WHERE t = '" + cmd.enc_tableName + "' AND c IS NULL;";
|
||
|
||
PGresult *res_update_table = PQexec(conn, update_table_dek.c_str());
|
||
if (PQresultStatus(res_update_table) != PGRES_COMMAND_OK) {
|
||
fprintf(stderr, "Update table DEK failed: %s\n", PQerrorMessage(conn));
|
||
PQclear(res_update_table);
|
||
PQexec(conn, "ROLLBACK");
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
PQclear(res_update_table);
|
||
|
||
std::cout << "表级密钥更新成功" << std::endl;
|
||
}
|
||
|
||
if(!cmd.cols_set.empty() && cmd.type != ROTATE_TABLE){
|
||
// 获取所有待更新的列级密钥
|
||
std::unordered_map<std::string, std::string> column_deks;
|
||
DekInterface::getAllDekColLevelForUpdate(column_deks);
|
||
|
||
// 更新列级密钥
|
||
for(const auto& enc_col_name: cmd.enc_cols_set){
|
||
std::string dek = column_deks[enc_col_name];
|
||
|
||
if (dek.empty()) {
|
||
fprintf(stderr, "Encrypt column DEK failed for column %s\n", enc_col_name.c_str());
|
||
continue;
|
||
}
|
||
|
||
// 加密列级密钥
|
||
// cmk_mapperEncryptDek(dek, cmk);
|
||
KMSInterface::encryptData(dek);
|
||
|
||
// 更新列级密钥
|
||
std::string update_col_dek = "UPDATE dek_store SET dek = '" + dek +
|
||
"' WHERE t = '" + cmd.enc_tableName +
|
||
"' AND c = '" + enc_col_name + "';";
|
||
|
||
PGresult *res_update_col = PQexec(conn, update_col_dek.c_str());
|
||
if (PQresultStatus(res_update_col) != PGRES_COMMAND_OK) {
|
||
fprintf(stderr, "Update column DEK failed for column %s: %s\n",
|
||
enc_col_name.c_str(), PQerrorMessage(conn));
|
||
PQclear(res_update_col);
|
||
continue;
|
||
}
|
||
PQclear(res_update_col);
|
||
|
||
|
||
std::cout << "列 " << enc_col_name << " 密钥更新成功" << std::endl;
|
||
}
|
||
}
|
||
|
||
|
||
// 提交事务
|
||
PGresult *res_commit = PQexec(conn, "COMMIT");
|
||
if (PQresultStatus(res_commit) != PGRES_COMMAND_OK) {
|
||
fprintf(stderr, "COMMIT transaction failed: %s\n", PQerrorMessage(conn));
|
||
PQclear(res_commit);
|
||
PQexec(conn, "ROLLBACK");
|
||
PQfinish(conn);
|
||
return;
|
||
}
|
||
PQclear(res_commit);
|
||
|
||
std::cout << "密钥轮换完成" << std::endl;
|
||
|
||
PQfinish(conn);
|
||
}
|
||
|
||
void connectionInsertTest() {
|
||
// 数据库连接信息
|
||
const char *conninfo1 = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
|
||
PGconn *conn1 = PQconnectdb(conninfo1);
|
||
if (PQstatus(conn1) != CONNECTION_OK) {
|
||
fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn1));
|
||
PQfinish(conn1);
|
||
}
|
||
|
||
std::string table_dek;
|
||
if (KMSInterface::createDek(table_dek, "")) { // 表级密钥使用空列名
|
||
std::string query1 = "insert into dek_store values";
|
||
std::string tuple1 = "(";
|
||
|
||
tuple1.append("'").append(DekInterface::getInfoUser()).append("',")
|
||
.append("'").append(DekInterface::getInfoDb()).append("',")
|
||
.append("'").append(DekInterface::getInfoTable()).append("',")
|
||
.append("NULL,")
|
||
.append("'").append(table_dek).append("')");
|
||
|
||
// 获取列信息并插入列级密钥
|
||
std::string tmp1 = DekInterface::getInfoCol();
|
||
if (!tmp1.empty()) {
|
||
tuple1.append(",");
|
||
}
|
||
query1.append(tuple1);
|
||
|
||
// 处理所有列密钥
|
||
while (!tmp1.empty()) {
|
||
std::string col_dek;
|
||
if(KMSInterface::createDek(col_dek, tmp1)) {
|
||
std::string tuple_col = "(";
|
||
tuple_col.append("'").append(DekInterface::getInfoUser()).append("',")
|
||
.append("'").append(DekInterface::getInfoDb()).append("',")
|
||
.append("'").append(DekInterface::getInfoTable()).append("',")
|
||
.append("'").append(tmp1).append("',")
|
||
.append("'").append(col_dek).append("')");
|
||
|
||
tmp1 = DekInterface::getInfoCol();
|
||
if (!tmp1.empty()) {
|
||
tuple_col.append(",");
|
||
}
|
||
query1.append(tuple_col);
|
||
}
|
||
}
|
||
query1.append(";");
|
||
std::cout << query1 << std::endl;
|
||
const char *query1_c = query1.c_str();
|
||
|
||
// 执行查询
|
||
PGresult *res1 = PQexec(conn1, query1_c);
|
||
// 检查查询状态
|
||
if (PQresultStatus(res1) != PGRES_COMMAND_OK) {
|
||
fprintf(stderr, "Query execution failed: %s\n", PQerrorMessage(conn1));
|
||
PQclear(res1);
|
||
} else {
|
||
PQclear(res1); // 释放查询结果
|
||
}
|
||
}
|
||
PQfinish(conn1); // 关闭数据库连接
|
||
}
|
||
|
||
|
||
void connectionSelectTest() {
|
||
const char *conninfo1 = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
|
||
PGconn *conn1 = PQconnectdb(conninfo1);
|
||
if (PQstatus(conn1) != CONNECTION_OK) {
|
||
fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn1));
|
||
PQfinish(conn1);
|
||
}
|
||
// 执行查询
|
||
std::string query1 = "SELECT * FROM dek_store WHERE t = \'" + DekInterface::getInfoTable() + "\';";
|
||
PGresult *res1 = PQexec(conn1, query1.c_str());
|
||
// 检查查询状态
|
||
if (PQresultStatus(res1) != PGRES_TUPLES_OK) {
|
||
fprintf(stderr, "Query execution failed: %s\n", PQerrorMessage(conn1));
|
||
PQclear(res1);
|
||
PQfinish(conn1);
|
||
}
|
||
int nrows = PQntuples(res1); // 获取查询结果的行数
|
||
int ncols = PQnfields(res1); // 获取查询结果的列数
|
||
for (int i = 0; i < nrows; i++) { // 遍历每一行
|
||
for (int j = 0; j < ncols; j++) { // 遍历每一列
|
||
char *value = PQgetvalue(res1, i, j); // 获取第 i 行,第 j 列的值
|
||
std::cout << value << "\t";
|
||
}
|
||
std::cout << std::endl; // 换行表示下一行
|
||
}
|
||
PQclear(res1); // 释放查询结果
|
||
PQfinish(conn1); // 关闭数据库连接
|
||
}
|
||
|
||
void connectionSelect() {
|
||
const char *conninfo1 = "dbname=dekmaster user=dekmaster password=secure_password hostaddr=127.0.0.1 port=5432";
|
||
PGconn *conn1 = PQconnectdb(conninfo1);
|
||
if (PQstatus(conn1) != CONNECTION_OK) {
|
||
fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn1));
|
||
PQfinish(conn1);
|
||
return;
|
||
}
|
||
|
||
// 执行查询
|
||
std::string table_name = DekInterface::getInfoTable();
|
||
|
||
std::string query1 = "SELECT c, dek FROM dek_store WHERE t = '" + table_name + "';";
|
||
PGresult *res1 = PQexec(conn1, query1.c_str());
|
||
|
||
// 检查查询状态
|
||
if (PQresultStatus(res1) != PGRES_TUPLES_OK) {
|
||
fprintf(stderr, "Query execution failed: %s\n", PQerrorMessage(conn1));
|
||
PQclear(res1);
|
||
PQfinish(conn1);
|
||
return;
|
||
}
|
||
|
||
// 获取查询结果的行数
|
||
int nrows = PQntuples(res1);
|
||
|
||
// 获取CMK
|
||
|
||
for (int i = 0; i < nrows; i++) { // 遍历每一行
|
||
std::string col_name(PQgetvalue(res1, i, 0)); // 列名
|
||
std::string dek(PQgetvalue(res1, i, 1)); // 对应的密钥
|
||
|
||
// 解密密钥
|
||
KMSInterface::decryptData(dek);
|
||
|
||
if (PQgetisnull(res1, i, 0)) { // 如果列名是NULL,表示这是表级密钥
|
||
DekInterface::setDekTableLevel(dek);
|
||
} else { // 否则是列级密钥
|
||
DekInterface::setDekColLevel(col_name, dek);
|
||
}
|
||
}
|
||
|
||
PQclear(res1); // 释放查询结果
|
||
PQfinish(conn1); // 关闭数据库连接
|
||
}
|
||
|
||
// 读取 JSON 文件
|
||
json read_json_from_file(const std::string& file_path) {
|
||
std::ifstream file(file_path);
|
||
if (!file.is_open()) {
|
||
std::cerr << "无法打开文件:" << file_path << std::endl;
|
||
exit(1);
|
||
}
|
||
|
||
json j;
|
||
file >> j;
|
||
return j;
|
||
}
|
||
|
||
// 获取表中所有列
|
||
std::vector<std::string> get_columns(const std::string& table_name, const json& j) {
|
||
std::vector<std::string> columns;
|
||
try {
|
||
// 获取指定表名的列信息
|
||
auto columns_info = j["schema"]["public"][table_name]["columns"];
|
||
for (const auto& col : columns_info) {
|
||
columns.push_back(col);
|
||
}
|
||
} catch (const std::exception& e) {
|
||
std::cerr << "获取列时发生错误:" << e.what() << std::endl;
|
||
}
|
||
return columns;
|
||
}
|
||
|
||
// 获取表中所有列
|
||
std::set<std::string> get_all_columns(const std::string& table_name, const json& j) {
|
||
std::set<std::string> cols;
|
||
try {
|
||
// 定位到 schema.public.{table_name}.columns
|
||
auto cols_array = j.at("schema").at("public").at(table_name).at("columns");
|
||
for (const auto& col : cols_array) {
|
||
cols.insert(col.get<std::string>());
|
||
}
|
||
} catch (const std::exception& e) {
|
||
std::cerr << "获取所有列名时发生错误:" << e.what() << std::endl;
|
||
}
|
||
return cols;
|
||
}
|
||
|
||
// 获取所有密文列名 ↔ 明文列名 之间的映射,存储在一个 std::map 中
|
||
std::unordered_map<std::string, std::string> get_column_mapping(const std::string& table_name, const json& j) {
|
||
std::unordered_map<std::string, std::string> mapping;
|
||
try {
|
||
// 定位到 schema.public.{table_name}.map
|
||
auto map_info = j.at("schema").at("public").at(table_name).at("map");
|
||
for (auto it = map_info.begin(); it != map_info.end(); ++it) {
|
||
// key: 列名1(可能是加密后的列名或明文标识)
|
||
// value: 列名2(对应的另一端)
|
||
mapping[it.key()] = it.value().get<std::string>();
|
||
}
|
||
} catch (const std::exception& e) {
|
||
std::cerr << "获取列映射时发生错误:" << e.what() << std::endl;
|
||
}
|
||
return mapping;
|
||
}
|
||
|
||
// 获取列的数据类型
|
||
std::string get_column_type(const std::string& table_name, const std::string& column_name, const json& j) {
|
||
try {
|
||
// 获取表的信息
|
||
auto table_info = j["schema"]["public"][table_name];
|
||
|
||
// 检查列是否存在
|
||
if (table_info.contains(column_name)) {
|
||
// 返回列的类型
|
||
if (table_info[column_name].contains("type")) {
|
||
return table_info[column_name]["type"];
|
||
}
|
||
}
|
||
} catch (const std::exception& e) {
|
||
std::cerr << "获取列类型时发生错误:" << e.what() << std::endl;
|
||
}
|
||
return "";
|
||
}
|
||
|
||
/**
|
||
* 执行数据库操作的辅助函数
|
||
* @param sql SQL语句
|
||
* @param errorMsg 错误描述(用于日志)
|
||
* @return 0表示成功,-1表示失败
|
||
*/
|
||
int executeSQL(const char *sql, const char *errorMsg) {
|
||
PGconn *conn = PQconnectdb(DB_CONNINFO);
|
||
PGresult *res = NULL;
|
||
int result = -1;
|
||
|
||
if (PQstatus(conn) != CONNECTION_OK) {
|
||
fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
|
||
PQfinish(conn);
|
||
return -1;
|
||
}
|
||
|
||
res = PQexec(conn, sql);
|
||
|
||
if (PQresultStatus(res) == PGRES_COMMAND_OK) {
|
||
result = 0;
|
||
printf("Database operation succeeded: %s\n", errorMsg);
|
||
} else {
|
||
fprintf(stderr, "Database operation failed (%s): %s\n", errorMsg, PQerrorMessage(conn));
|
||
fprintf(stderr, "SQL: %s\n", sql);
|
||
}
|
||
|
||
PQclear(res);
|
||
PQfinish(conn);
|
||
return result;
|
||
}
|
||
|
||
/**
|
||
* 检查表是否存在
|
||
* @param tableName 表名
|
||
* @return 1表示存在,0表示不存在,-1表示查询失败
|
||
*/
|
||
int tableExists(const char *tableName) {
|
||
PGconn *conn = PQconnectdb(DB_CONNINFO);
|
||
PGresult *res = NULL;
|
||
int result = -1;
|
||
char query[256];
|
||
|
||
if (PQstatus(conn) != CONNECTION_OK) {
|
||
fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
|
||
PQfinish(conn);
|
||
return -1;
|
||
}
|
||
|
||
snprintf(query, sizeof(query),
|
||
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '%s')",
|
||
tableName);
|
||
|
||
res = PQexec(conn, query);
|
||
|
||
if (PQresultStatus(res) == PGRES_TUPLES_OK && PQntuples(res) > 0) {
|
||
const char *exists = PQgetvalue(res, 0, 0);
|
||
result = (strcmp(exists, "t") == 0) ? 1 : 0;
|
||
} else {
|
||
fprintf(stderr, "Failed to check table existence: %s\n", PQerrorMessage(conn));
|
||
}
|
||
|
||
PQclear(res);
|
||
PQfinish(conn);
|
||
return result;
|
||
}
|
||
|
||
/**
|
||
* 备份dek_store表
|
||
* @return 0表示成功,-1表示失败
|
||
*/
|
||
int backupDekStore() {
|
||
char sql[512];
|
||
int result = 0;
|
||
|
||
// 检查原表是否存在
|
||
int originalExists = tableExists(ORIGINAL_TABLE_NAME);
|
||
if (originalExists < 0) {
|
||
fprintf(stderr, "Failed to check if original table exists\n");
|
||
return -1;
|
||
}
|
||
|
||
if (originalExists == 0) {
|
||
printf("Original table %s does not exist, skipping database backup\n", ORIGINAL_TABLE_NAME);
|
||
return 0;
|
||
}
|
||
|
||
// 删除可能存在的备份表
|
||
snprintf(sql, sizeof(sql), "DROP TABLE IF EXISTS %s", BACKUP_TABLE_NAME);
|
||
if (executeSQL(sql, "drop existing backup table") != 0) {
|
||
return -1;
|
||
}
|
||
|
||
// 创建备份表并复制数据
|
||
snprintf(sql, sizeof(sql),
|
||
"CREATE TABLE %s AS SELECT * FROM %s",
|
||
BACKUP_TABLE_NAME, ORIGINAL_TABLE_NAME);
|
||
if (executeSQL(sql, "create backup table") != 0) {
|
||
return -1;
|
||
}
|
||
|
||
printf("Database table %s backed up successfully to %s\n", ORIGINAL_TABLE_NAME, BACKUP_TABLE_NAME);
|
||
return 0;
|
||
}
|
||
|
||
/**
|
||
* 删除dek_store备份表
|
||
* @return 0表示成功,-1表示失败
|
||
*/
|
||
int deleteDekStoreBackup() {
|
||
char sql[256];
|
||
|
||
// 检查备份表是否存在
|
||
int backupExists = tableExists(BACKUP_TABLE_NAME);
|
||
if (backupExists < 0) {
|
||
fprintf(stderr, "Failed to check if backup table exists\n");
|
||
return -1;
|
||
}
|
||
|
||
if (backupExists == 0) {
|
||
printf("Backup table %s does not exist\n", BACKUP_TABLE_NAME);
|
||
return 0;
|
||
}
|
||
|
||
// 删除备份表
|
||
snprintf(sql, sizeof(sql), "DROP TABLE %s", BACKUP_TABLE_NAME);
|
||
if (executeSQL(sql, "delete backup table") != 0) {
|
||
return -1;
|
||
}
|
||
|
||
printf("Database backup table %s deleted successfully\n", BACKUP_TABLE_NAME);
|
||
return 0;
|
||
}
|
||
|
||
/**
|
||
* 从备份恢复dek_store表
|
||
* @return 0表示成功,-1表示失败
|
||
*/
|
||
int restoreDekStore() {
|
||
char sql[512];
|
||
|
||
// 检查备份表是否存在
|
||
int backupExists = tableExists(BACKUP_TABLE_NAME);
|
||
if (backupExists < 0) {
|
||
fprintf(stderr, "Failed to check if backup table exists\n");
|
||
return -1;
|
||
}
|
||
|
||
if (backupExists == 0) {
|
||
printf("No backup table %s found for rollback\n", BACKUP_TABLE_NAME);
|
||
return 0;
|
||
}
|
||
|
||
// 删除当前表(如果存在)
|
||
snprintf(sql, sizeof(sql), "DROP TABLE IF EXISTS %s", ORIGINAL_TABLE_NAME);
|
||
if (executeSQL(sql, "drop current table for restore") != 0) {
|
||
return -1;
|
||
}
|
||
|
||
// 重命名备份表为原表名
|
||
snprintf(sql, sizeof(sql),
|
||
"ALTER TABLE %s RENAME TO %s",
|
||
BACKUP_TABLE_NAME, ORIGINAL_TABLE_NAME);
|
||
if (executeSQL(sql, "restore table from backup") != 0) {
|
||
return -1;
|
||
}
|
||
|
||
printf("Database table %s restored from backup successfully\n", ORIGINAL_TABLE_NAME);
|
||
return 0;
|
||
}
|
||
} |