/*
 * ldb/kvs_protocol_resp.c
 * RESP2 request parsing, reply serialization and command dispatch.
 *
 * Code-hosting page residue, kept verbatim: "Files", "552 lines", "17 KiB",
 * "C", "Raw Blame History", and the notice "This file contains ambiguous
 * Unicode characters ... Use the Escape button to reveal them." (presumably
 * triggered by the full-width punctuation in the Chinese comments).
 */
#include "kvs_protocol_resp.h"

#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include "kvs_rw_tools.h"
#if ENABLE_ARRAY
extern kvs_array_t global_array;
#endif
#if ENABLE_RBTREE
extern kvs_rbtree_t global_rbtree;
#endif
#if ENABLE_HASH
extern kvs_hash_t global_hash;
#endif
/* Bounds check: 0 when at least n bytes are available in [p, end), -1 otherwise.
 * Requires p <= end. The remaining length is compared instead of forming
 * p + n, because pointer arithmetic beyond one-past-the-end of the buffer is
 * undefined behaviour and, for a huge n (e.g. a large bulk length), can wrap
 * around and wrongly report that the data is available. */
static int need(const uint8_t *p, const uint8_t *end, size_t n) {
    return ((size_t)(end - p) >= n) ? 0 : -1;
}
/* Locate the first "\r\n" in [p, end).
 * At most RESP_MAX_LINE bytes are examined, so the line including its CRLF
 * must fit in RESP_MAX_LINE bytes to be found. On success stores a pointer to
 * the '\r' in *line_end and returns 0. Returns -1 when no CRLF lies inside
 * that window; the return value alone does not tell whether the line is merely
 * incomplete or is longer than the limit. Requires p <= end. */
static int find_crlf(const uint8_t *p, const uint8_t *end, const uint8_t **line_end) {
    const size_t avail = (size_t)(end - p);
    const size_t window = (avail < (size_t)RESP_MAX_LINE) ? avail : (size_t)RESP_MAX_LINE;

    for (size_t i = 0; i + 1 < window; i++) {
        if (p[i] != '\r' || p[i + 1] != '\n') {
            continue;
        }
        *line_end = p + i;
        return 0;
    }
    return -1;
}
/* Parse a signed decimal integer from [p, line_end).
 * Accepts an optional leading '-' followed by one or more ASCII digits and
 * nothing else. Returns 0 and stores the value in *out on success.
 * Returns -1 on:
 *   - empty input or a lone '-';
 *   - any non-digit byte;
 *   - a magnitude above INT64_MAX (so INT64_MIN itself is rejected).
 * *out is left untouched on failure. */
static int parse_i64(const uint8_t *p, const uint8_t *line_end, int64_t *out) {
    if (p == NULL || line_end == NULL || p >= line_end) {
        return -1;
    }

    const int negative = (p[0] == '-');
    const uint8_t *digits = negative ? p + 1 : p;
    if (digits >= line_end) {
        return -1;
    }

    int64_t value = 0;
    while (digits < line_end) {
        const uint8_t c = *digits++;
        if (c < '0' || c > '9') {
            return -1;
        }
        const int d = (int)(c - '0');
        /* value * 10 + d must not exceed INT64_MAX. */
        if (value > (INT64_MAX - d) / 10) {
            return -1;
        }
        value = value * 10 + d;
    }

    *out = negative ? -value : value;
    return 0;
}
/* Compare the byte string a[0..alen) with the NUL-terminated string b,
 * folding only ASCII 'A'..'Z' to lower case.
 * Returns 0 when they are equal and -1 otherwise. Unlike strcasecmp, the
 * result carries no ordering information. */
int ascii_casecmp(const uint8_t *a, uint32_t alen, const char *b) {
    const size_t blen = strlen(b);
    if ((uint32_t)blen != alen) {
        return -1;
    }

    for (uint32_t k = 0; k != alen; ++k) {
        uint8_t x = a[k];
        uint8_t y = (uint8_t)b[k];
        x = (x >= 'A' && x <= 'Z') ? (uint8_t)(x - 'A' + 'a') : x;
        y = (y >= 'A' && y <= 'Z') ? (uint8_t)(y - 'A' + 'a') : y;
        if (x != y) {
            return -1;
        }
    }
    return 0;
}
/* ----------------- RESP2 parser (one command) ----------------- */
/* Parse one bulk string: $<len>\r\n<bytes>\r\n
 * On success, points *out at <bytes> without copying (the slice borrows the
 * input buffer) and advances *pp past the trailing CRLF.
 * Returns 1 on success and 0 when more data is needed.
 * Returns -1 on a protocol error:
 *   - the first byte is not '$';
 *   - the length is non-numeric or negative (nil "$-1" is not a valid argument);
 *   - the length is above RESP_MAX_BULK;
 *   - the payload is not followed by CRLF;
 *   - the length line does not end within RESP_MAX_LINE bytes. */
static int parse_bulk(const uint8_t **pp, const uint8_t *end, resp_slice_t *out) {
    const uint8_t *p = *pp;
    if (need(p, end, 1) < 0) return 0;
    if (*p != '$') return -1;
    p++;

    /* Find the end of the <len> line. */
    const uint8_t *le = NULL;
    if (find_crlf(p, end, &le) < 0) {
        /* find_crlf scans at most RESP_MAX_LINE bytes. If that many bytes are
         * already buffered without a CRLF, more data cannot help: report a
         * protocol error instead of asking for more input forever. */
        return ((size_t)(end - p) >= (size_t)RESP_MAX_LINE) ? -1 : 0;
    }

    int64_t n64 = 0;
    if (parse_i64(p, le, &n64) < 0) return -1;
    p = le + 2; /* skip CRLF */

    /* A nil bulk ($-1) is not a valid request argument. */
    if (n64 < 0) return -1;
    if ((uint64_t)n64 > (uint64_t)RESP_MAX_BULK) return -1;

    uint32_t n = (uint32_t)n64;
    /* The payload plus its trailing CRLF must be fully buffered. */
    if (need(p, end, (size_t)n + 2) < 0) return 0;

    const uint8_t *bytes = p;
    p += n;
    if (p[0] != '\r' || p[1] != '\n') return -1;
    p += 2;

    out->ptr = bytes;
    out->len = n;
    *pp = p;
    return 1;
}
/* 解析数组头: *<n>\r\n */
static int parse_array_len(const uint8_t **pp, const uint8_t *end, int64_t *out_n) {
const uint8_t *p = *pp;
if (need(p, end, 1) < 0) return 0;
if (*p != '*') return -1;
p++;
const uint8_t *le = NULL;
if (find_crlf(p, end, &le) < 0) return 0;
int64_t n64 = 0;
if (parse_i64(p, le, &n64) < 0) return -1;
p = le + 2;
*pp = p;
*out_n = n64;
return 1;
}
/* Parse an inline command: one CRLF-terminated line of words separated by
 * spaces/tabs. argv slices point into buf (no copy).
 * Returns the number of bytes consumed (> 0), including the CRLF.
 * Returns 0 when the line is not complete yet.
 * Returns -1 on error:
 *   - the line holds no words;
 *   - it holds more than RESP_MAX_ARGC words;
 *   - no CRLF appears within RESP_MAX_LINE bytes.
 * out_cmd->argc is 0 unless a complete command was parsed. */
static int parse_inline(const uint8_t *buf, int len, resp_cmd_t *out_cmd) {
    const uint8_t *end = buf + (size_t)len;
    const uint8_t *le = NULL;

    out_cmd->argc = 0;
    if (find_crlf(buf, end, &le) < 0) {
        /* find_crlf scans at most RESP_MAX_LINE bytes. If that many bytes are
         * already buffered without a CRLF, the line can never be accepted:
         * report an error instead of asking for more data forever. */
        return ((size_t)(end - buf) >= (size_t)RESP_MAX_LINE) ? -1 : 0;
    }

    /* Split [buf, le) on spaces and tabs. */
    uint32_t argc = 0;
    const uint8_t *s = buf;
    while (s < le) {
        while (s < le && (*s == ' ' || *s == '\t')) s++;
        if (s >= le) break;

        const uint8_t *t = s;
        while (t < le && *t != ' ' && *t != '\t') t++;

        if (argc >= RESP_MAX_ARGC) return -1;
        out_cmd->argv[argc].ptr = s;
        out_cmd->argv[argc].len = (uint32_t)(t - s);
        argc++;
        s = t;
    }

    if (argc == 0) return -1;
    out_cmd->argc = argc;
    return (int)((le + 2) - buf);
}
/* Parse exactly one command from buf[0..len).
 * A leading '*' selects the multi-bulk form (*<n>\r\n followed by n bulk
 * strings); anything else is parsed as an inline command. argv slices point
 * into buf and are not copied, so buf must outlive *out_cmd.
 * Returns the number of bytes consumed (> 0).
 * Returns 0 when the command is still incomplete.
 * Returns -1 on a protocol error, a NULL pointer, or len <= 0.
 * *out_cmd is zeroed first. On the multi-bulk path argc is set only after every
 * argument has been parsed. An incomplete or rejected multi-bulk command
 * therefore never shows a non-zero argc with unfilled argv slots. */
int resp_parse_one_cmd(const uint8_t *buf, int len, resp_cmd_t *out_cmd) {
    if (!buf || len <= 0 || !out_cmd) return -1;
    memset(out_cmd, 0, sizeof(*out_cmd));

    if (buf[0] != '*') {
        return parse_inline(buf, len, out_cmd);
    }

    const uint8_t *p = buf;
    const uint8_t *end = buf + (size_t)len;

    int64_t n64 = 0;
    int r = parse_array_len(&p, end, &n64);
    if (r == 0) return 0;
    if (r < 0) return -1;
    /* Empty (*0), negative and over-limit argument counts are rejected. */
    if (n64 <= 0 || n64 > (int64_t)RESP_MAX_ARGC) return -1;

    const uint32_t argc = (uint32_t)n64;
    for (uint32_t i = 0; i < argc; i++) {
        resp_slice_t sl = {0};
        int rr = parse_bulk(&p, end, &sl);
        if (rr == 0) return 0;
        if (rr < 0) return -1;
        out_cmd->argv[i] = sl;
    }

    out_cmd->argc = argc;
    return (int)(p - buf);
}
/* ----------------- RESP2 builder ----------------- */
/* Append n bytes from src at *pp and advance *pp, never writing at or past end.
 * Returns 0 on success. Returns -1, leaving *pp unchanged, when fewer than n
 * bytes of space remain. src may be NULL when n == 0. memcpy is skipped then,
 * because passing NULL to memcpy is undefined even for a zero length. */
static int write_bytes(uint8_t **pp, const uint8_t *end, const void *src, size_t n) {
    uint8_t *p = *pp;
    if ((size_t)(end - p) < n) return -1;
    if (n == 0) return 0;
    memcpy(p, src, n);
    *pp = p + n;
    return 0;
}
/* Append the RESP line terminator "\r\n".
 * Returns 0, or -1 if fewer than two bytes of space remain. */
static int write_crlf(uint8_t **pp, const uint8_t *end) {
    static const uint8_t crlf[2] = { '\r', '\n' };
    return write_bytes(pp, end, crlf, sizeof crlf);
}
/* Append x as signed decimal text (leading '-' when negative, no NUL).
 * Returns 0, or -1 if formatting fails or the text does not fit. */
static int write_i64_ascii(uint8_t **pp, const uint8_t *end, int64_t x) {
    char digits[64];
    const int len = snprintf(digits, sizeof digits, "%lld", (long long)x);
    if (len > 0) {
        return write_bytes(pp, end, digits, (size_t)len);
    }
    return -1;
}
/* Append x as unsigned decimal text (no sign, no padding, no NUL).
 * Returns 0, or -1 if formatting fails or the text does not fit. */
static int write_u32_ascii(uint8_t **pp, const uint8_t *end, uint32_t x) {
    char digits[32];
    const int len = snprintf(digits, sizeof digits, "%u", (unsigned)x);
    if (len > 0) {
        return write_bytes(pp, end, digits, (size_t)len);
    }
    return -1;
}
/* Append "<prefix><payload>\r\n", the line form of a simple string or error.
 * RESP forbids CR and LF inside such a payload: one would end the line early
 * and let the rest be read as a separate reply, so the payload is rejected.
 * A NULL payload is accepted only when len is 0.
 * Returns 0 on success, or -1 on an invalid payload or insufficient space. */
static int write_simple_line(uint8_t **pp, const uint8_t *end, char prefix,
                             const uint8_t *ptr, uint32_t len) {
    if (len > 0) {
        if (ptr == NULL) return -1;
        if (memchr(ptr, '\r', len) != NULL || memchr(ptr, '\n', len) != NULL) return -1;
    }
    if (write_bytes(pp, end, &prefix, 1) < 0) return -1;
    if (len > 0 && write_bytes(pp, end, ptr, len) < 0) return -1;
    return write_crlf(pp, end);
}

/* Serialize v as RESP2 into out[0..cap).
 * Returns the number of bytes written.
 * Returns -1 when:
 *   - v or out is NULL, or cap is 0;
 *   - the type is unsupported;
 *   - the payload is invalid (NULL with a non-zero length, or CR/LF inside a
 *     simple string/error);
 *   - the reply does not fit.
 * On -1, out may hold a partially written reply. */
int resp_build_value(const resp_value_t *v, uint8_t *out, size_t cap) {
    if (!v || !out || cap == 0) return -1;
    uint8_t *p = out;
    const uint8_t *end = out + cap;

    switch (v->type) {
    case RESP_T_SIMPLE_STR:
        if (write_simple_line(&p, end, '+', v->bulk.ptr, v->bulk.len) < 0) return -1;
        break;
    case RESP_T_ERROR:
        if (write_simple_line(&p, end, '-', v->bulk.ptr, v->bulk.len) < 0) return -1;
        break;
    case RESP_T_INTEGER:
        if (write_bytes(&p, end, ":", 1) < 0) return -1;
        if (write_i64_ascii(&p, end, v->i64) < 0) return -1;
        if (write_crlf(&p, end) < 0) return -1;
        break;
    case RESP_T_NIL:
        if (write_bytes(&p, end, "$-1\r\n", 5) < 0) return -1;
        break;
    case RESP_T_BULK_STR:
        /* "$<len>\r\n" promises len payload bytes. Without a payload pointer
         * the reply would be malformed, so refuse instead. */
        if (v->bulk.len > 0 && v->bulk.ptr == NULL) return -1;
        if (write_bytes(&p, end, "$", 1) < 0) return -1;
        if (write_u32_ascii(&p, end, v->bulk.len) < 0) return -1;
        if (write_crlf(&p, end) < 0) return -1;
        if (v->bulk.len > 0 && write_bytes(&p, end, v->bulk.ptr, v->bulk.len) < 0) return -1;
        if (write_crlf(&p, end) < 0) return -1;
        break;
    default:
        return -1;
    }
    return (int)(p - out);
}
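/* Reply-value constructors. String payloads are stored by pointer, not copied,
 * so the caller's buffer must outlive the returned value. */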
resp_value_t resp_simple(const char *s) {
resp_value_t v;
v.type = RESP_T_SIMPLE_STR;
v.i64 = 0;
v.bulk.ptr = (const uint8_t*)s;
v.bulk.len = (uint32_t)strlen(s);
return v;
}
resp_value_t resp_error(const char *s) {
resp_value_t v;
v.type = RESP_T_ERROR;
v.i64 = 0;
v.bulk.ptr = (const uint8_t*)s;
v.bulk.len = (uint32_t)strlen(s);
return v;
}
resp_value_t resp_int(int64_t x) {
resp_value_t v;
v.type = RESP_T_INTEGER;
v.i64 = x;
v.bulk.ptr = NULL;
v.bulk.len = 0;
return v;
}
resp_value_t resp_bulk(const uint8_t *p, uint32_t n) {
resp_value_t v;
v.type = RESP_T_BULK_STR;
v.i64 = 0;
v.bulk.ptr = p;
v.bulk.len = n;
return v;
}
resp_value_t resp_nil(void) {
resp_value_t v;
v.type = RESP_T_NIL;
v.i64 = 0;
v.bulk.ptr = NULL;
v.bulk.len = 0;
return v;
}
/* ----------------- dispatcher (minimal) ----------------- */
static int expect_argv(const resp_cmd_t *cmd, uint32_t n) {
return (cmd && cmd->argc == n) ? 0 : -1;
}
const char *command[] = {
"SET", "GET", "DEL", "MOD", "EXIST",
"RSET", "RGET", "RDEL", "RMOD", "REXIST",
"HSET", "HGET", "HDEL", "HMOD", "HEXIST",
"SAVE", "PSYNC"
};
/**
* 输入cmd
* 输出out_value
* 返回:-1 失败参数错误0 成功
*/
int resp_dispatch(const resp_cmd_t *cmd, resp_value_t *out_value) {
if (!cmd || !out_value) return -1;
if (cmd->argc == 0 || cmd->argv[0].ptr == NULL) {
*out_value = resp_error("ERR empty command");
return 0;
}
const uint8_t *cptr = cmd->argv[0].ptr;
uint32_t clen = cmd->argv[0].len;
kvs_cmd_t op = KVS_CMD_COUNT;
for(kvs_cmd_t i = KVS_CMD_START; i < KVS_CMD_COUNT; ++ i){
if(ascii_casecmp(cptr, clen, command[i]) == 0){
op = i;
break;
}
}
if (op == KVS_CMD_COUNT) {
*out_value = resp_error("ERR unknown command");
return -1;
}
switch (op) {
#if ENABLE_ARRAY
case KVS_CMD_SET: {
if (cmd->argc != 3) { *out_value = resp_error("ERR wrong number of arguments for 'set'"); return -1; }
// <0 error; 0 success; 1 exist
int r = kvs_array_set_bin(&global_array,
cmd->argv[1].ptr, cmd->argv[1].len,
cmd->argv[2].ptr, cmd->argv[2].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
else if (r == 1) { *out_value = resp_error("ERR key has exist"); return 0; }
*out_value = resp_simple("OK");
return 0;
}
case KVS_CMD_GET: {
if (cmd->argc != 2) { *out_value = resp_error("ERR wrong number of arguments for 'get'"); return -1; }
uint32_t vlen = 0;
// NULL not exist, NOTNULL exist
const char *v = kvs_array_get_bin(&global_array, cmd->argv[1].ptr, cmd->argv[1].len, &vlen);
if (!v) { *out_value = resp_nil(); return 0; }
*out_value = resp_bulk((const uint8_t*)v, vlen);
return 0;
}
case KVS_CMD_DEL: {
if (cmd->argc != 2) { *out_value = resp_error("ERR wrong number of arguments for 'del'"); return -1; }
// <0 error; =0 success; >0 no exist
int r = kvs_array_del_bin(&global_array, cmd->argv[1].ptr, cmd->argv[1].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
// r == 0, del 1; r > 0, del 0.
*out_value = resp_int((r == 0) ? 1 : 0);
return 0;
}
case KVS_CMD_MOD: {
if (cmd->argc != 3) { *out_value = resp_error("ERR wrong number of arguments for 'mod'"); return -1; }
// <0 error; =0 success; >0 no exist
int r = kvs_array_mod_bin(&global_array,
cmd->argv[1].ptr, cmd->argv[1].len,
cmd->argv[2].ptr, cmd->argv[2].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
if (r == 0) *out_value = resp_simple("OK");
else *out_value = resp_error("ERR no such key");
return 0;
}
case KVS_CMD_EXIST: {
if (cmd->argc != 2) { *out_value = resp_error("ERR wrong number of arguments for 'exist'"); return -1; }
// =0 exist, =1 no exist
int r = kvs_array_exist_bin(&global_array, cmd->argv[1].ptr, cmd->argv[1].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
*out_value = resp_int((r == 0) ? 1 : 0);
return 0;
}
#endif
#if ENABLE_RBTREE
case KVS_CMD_RSET: {
if (cmd->argc != 3) { *out_value = resp_error("ERR wrong number of arguments for 'rset'"); return 0; }
// <0 error; 0 success; 1 exist
int r = kvs_rbtree_set(&global_rbtree,
cmd->argv[1].ptr, cmd->argv[1].len,
cmd->argv[2].ptr, cmd->argv[2].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
else if (r == 1) { *out_value = resp_error("ERR key has exist"); return 0; }
*out_value = resp_simple("OK");
return 0;
}
case KVS_CMD_RGET: {
if (cmd->argc != 2) { *out_value = resp_error("ERR wrong number of arguments for 'rget'"); return 0; }
uint32_t vlen = 0;
// NULL notexist, NOTNULL exist。out_value_len 是长度。
const char *v = kvs_rbtree_get(&global_rbtree, cmd->argv[1].ptr, cmd->argv[1].len, &vlen);
if (!v) { *out_value = resp_nil(); return 0; }
*out_value = resp_bulk((const uint8_t*)v, vlen);
return 0;
}
case KVS_CMD_RDEL: {
if (cmd->argc != 2) { *out_value = resp_error("ERR wrong number of arguments for 'rdel'"); return 0; }
// <0 error; =0 success; >0 no exist
int r = kvs_rbtree_del(&global_rbtree, cmd->argv[1].ptr, cmd->argv[1].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
*out_value = resp_int((r == 0) ? 1 : 0);
return 0;
}
case KVS_CMD_RMOD: {
if (cmd->argc != 3) { *out_value = resp_error("ERR wrong number of arguments for 'rmod'"); return 0; }
// < 0 error; =0 success; >0 no exist
int r = kvs_rbtree_mod(&global_rbtree,
cmd->argv[1].ptr, cmd->argv[1].len,
cmd->argv[2].ptr, cmd->argv[2].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
if (r == 0) *out_value = resp_simple("OK");
else *out_value = resp_error("ERR no such key");
return 0;
}
case KVS_CMD_REXIST: {
if (cmd->argc != 2) { *out_value = resp_error("ERR wrong number of arguments for 'rexist'"); return 0; }
// =0 exist, =1 no exist
int r = kvs_rbtree_exist(&global_rbtree, cmd->argv[1].ptr, cmd->argv[1].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
*out_value = resp_int((r == 0) ? 1 : 0);
return 0;
}
#endif
#if ENABLE_HASH
case KVS_CMD_HSET: {
if (cmd->argc != 3) { *out_value = resp_error("ERR wrong number of arguments for 'hset'"); return 0; }
// <0 error; 0 success; 1 exist
int r = kvs_hash_set_bin(&global_hash,
cmd->argv[1].ptr, cmd->argv[1].len,
cmd->argv[2].ptr, cmd->argv[2].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
else if (r == 1) { *out_value = resp_error("ERR key has exist"); return 0; }
*out_value = resp_simple("OK");
return 0;
}
case KVS_CMD_HGET: {
if (cmd->argc != 2) { *out_value = resp_error("ERR wrong number of arguments for 'hget'"); return 0; }
uint32_t vlen = 0;
// NULL notexist, NOTNULL exist。out_value_len 是长度。
const char *v = kvs_hash_get_bin(&global_hash, cmd->argv[1].ptr, cmd->argv[1].len, &vlen);
if (!v) { *out_value = resp_nil(); return 0; }
*out_value = resp_bulk((const uint8_t*)v, vlen);
return 0;
}
case KVS_CMD_HDEL: {
if (cmd->argc != 2) { *out_value = resp_error("ERR wrong number of arguments for 'hdel'"); return 0; }
// <0 error; =0 success; >0 no exist
int r = kvs_hash_del_bin(&global_hash, cmd->argv[1].ptr, cmd->argv[1].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
*out_value = resp_int((r == 0) ? 1 : 0);
return 0;
}
case KVS_CMD_HMOD: {
if (cmd->argc != 3) { *out_value = resp_error("ERR wrong number of arguments for 'hmod'"); return 0; }
// <0 error; =0 success; >0 no exist
int r = kvs_hash_mod_bin(&global_hash,
cmd->argv[1].ptr, cmd->argv[1].len,
cmd->argv[2].ptr, cmd->argv[2].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
if (r == 0) *out_value = resp_simple("OK");
else *out_value = resp_error("ERR no such key");
return 0;
}
case KVS_CMD_HEXIST: {
if (cmd->argc != 2) { *out_value = resp_error("ERR wrong number of arguments for 'hexist'"); return 0; }
// =0 exist, =1 no exist
int r = kvs_hash_exist_bin(&global_hash, cmd->argv[1].ptr, cmd->argv[1].len);
if (r < 0) { *out_value = resp_error("ERR internal error"); return 0; }
*out_value = resp_int((r == 0) ? 1 : 0);
return 0;
}
#endif
/* ---------------- misc ---------------- */
case KVS_CMD_SAVE: {
if (cmd->argc != 1) { *out_value = resp_error("ERR wrong number of arguments for 'save'"); return 0; }
int r = kvs_save_to_file();
if (r < 0) { *out_value = resp_error("ERR save failed"); return 0; }
*out_value = resp_simple("OK");
return 0;
}
case KVS_CMD_PSYNC:
*out_value = resp_simple("OK");
return 0;
default:
break;
}
*out_value = resp_error("ERR unknown command");
return 0;
}