#include "kvs_protocol_resp.h"

#if ENABLE_ARRAY
extern kvs_array_t global_array;
#endif
#if ENABLE_RBTREE
extern kvs_rbtree_t global_rbtree;
#endif
#if ENABLE_HASH
extern kvs_hash_t global_hash;
#endif

/* Internal result codes of find_crlf(). */
enum {
    CRLF_FOUND = 0,      /* terminator located */
    CRLF_NEED_MORE = -1, /* no terminator yet; more input may complete the line */
    CRLF_TOO_LONG = -2   /* RESP_MAX_LINE bytes buffered without a terminator */
};

/* Return 0 if at least n bytes are available in [p, end), -1 otherwise.
 * Compares lengths instead of forming p + n, because a pointer computed past
 * the end of the buffer is undefined behavior. Requires p <= end. */
static int need(const uint8_t *p, const uint8_t *end, size_t n)
{
    return ((size_t)(end - p) >= n) ? 0 : -1;
}

/* Locate the first "\r\n" in [p, end), scanning at most RESP_MAX_LINE bytes.
 * On success stores the address of the '\r' in *line_end and returns
 * CRLF_FOUND. Returns CRLF_NEED_MORE when fewer than RESP_MAX_LINE bytes are
 * buffered (the terminator may still arrive), and CRLF_TOO_LONG otherwise:
 * since the scan window is capped, waiting for more data could never succeed,
 * so callers must treat that case as a protocol error. */
static int find_crlf(const uint8_t *p, const uint8_t *end, const uint8_t **line_end)
{
    size_t avail = (size_t)(end - p);
    size_t lim = (avail < (size_t)RESP_MAX_LINE) ? avail : (size_t)RESP_MAX_LINE;

    for (size_t i = 0; i + 1 < lim; i++) {
        if (p[i] == '\r' && p[i + 1] == '\n') {
            *line_end = p + i;
            return CRLF_FOUND;
        }
    }
    return (avail >= (size_t)RESP_MAX_LINE) ? CRLF_TOO_LONG : CRLF_NEED_MORE;
}

/* Parse an optionally '-'-prefixed decimal integer spanning exactly
 * [p, line_end). Returns 0 and stores the value in *out, or -1 on an empty
 * field, a non-digit byte, or a magnitude above INT64_MAX. A leading '+' is
 * not accepted. */
static int parse_i64(const uint8_t *p, const uint8_t *line_end, int64_t *out)
{
    if (!p || !line_end || p >= line_end) return -1;

    int neg = 0;
    if (*p == '-') {
        neg = 1;
        p++;
        if (p >= line_end) return -1;
    }

    int64_t x = 0;
    for (const uint8_t *q = p; q < line_end; q++) {
        if (*q < '0' || *q > '9') return -1;
        int digit = (int)(*q - '0');
        /* Reject before x * 10 + digit could overflow int64_t. */
        if (x > (INT64_MAX - digit) / 10) return -1;
        x = x * 10 + digit;
    }
    *out = neg ? -x : x;
    return 0;
}

/* ASCII case-insensitive comparison of the byte slice a[0..alen) with the
 * NUL-terminated string b. Returns 0 when equal, -1 otherwise. */
static int ascii_casecmp(const uint8_t *a, uint32_t alen, const char *b)
{
    size_t blen = strlen(b);
    if (alen != (uint32_t)blen) return -1;

    for (uint32_t i = 0; i < alen; i++) {
        uint8_t ca = a[i];
        uint8_t cb = (uint8_t)b[i];
        if (ca >= 'A' && ca <= 'Z') ca = (uint8_t)(ca + 32);
        if (cb >= 'A' && cb <= 'Z') cb = (uint8_t)(cb + 32);
        if (ca != cb) return -1;
    }
    return 0;
}

/* ----------------- RESP2 parser (one command) ----------------- */

/* Parse a bulk string "$<len>\r\n<bytes>\r\n" starting at *pp.
 * On success points *out at <bytes> (borrowed from the input, not copied) and
 * advances *pp past the trailing CRLF.
 * Returns 1 on success, 0 if more data is needed, -1 on a protocol error:
 * wrong prefix, malformed or negative length (nil bulk "$-1" is rejected),
 * length above RESP_MAX_BULK, an over-long length line, or a payload not
 * followed by CRLF. */
static int parse_bulk(const uint8_t **pp, const uint8_t *end, resp_slice_t *out)
{
    const uint8_t *p = *pp;
    if (need(p, end, 1) < 0) return 0;
    if (*p != '$') return -1;
    p++;

    /* Locate the end of the <len> line. */
    const uint8_t *le = NULL;
    int fr = find_crlf(p, end, &le);
    if (fr == CRLF_NEED_MORE) return 0;
    if (fr != CRLF_FOUND) return -1;

    int64_t n64 = 0;
    if (parse_i64(p, le, &n64) < 0) return -1;
    p = le + 2; /* skip CRLF */

    /* A nil bulk ($-1) is not a valid command argument. */
    if (n64 < 0) return -1;
    if ((uint64_t)n64 > (uint64_t)RESP_MAX_BULK) return -1;
    uint32_t n = (uint32_t)n64;

    /* Payload plus its trailing CRLF must be fully buffered. */
    if (need(p, end, (size_t)n + 2) < 0) return 0;

    const uint8_t *bytes = p;
    p += n;
    if (p[0] != '\r' || p[1] != '\n') return -1;
    p += 2;

    out->ptr = bytes;
    out->len = n;
    *pp = p;
    return 1;
}

/* Parse an array header "*<count>\r\n" starting at *pp.
 * On success stores <count> in *out_n (not range-checked here) and advances
 * *pp past the CRLF. Returns 1 on success, 0 if more data is needed, -1 on a
 * protocol error (wrong prefix, malformed count, or an over-long line). */
static int parse_array_len(const uint8_t **pp, const uint8_t *end, int64_t *out_n)
{
    const uint8_t *p = *pp;
    if (need(p, end, 1) < 0) return 0;
    if (*p != '*') return -1;
    p++;

    const uint8_t *le = NULL;
    int fr = find_crlf(p, end, &le);
    if (fr == CRLF_NEED_MORE) return 0;
    if (fr != CRLF_FOUND) return -1;

    int64_t n64 = 0;
    if (parse_i64(p, le, &n64) < 0) return -1;
    p = le + 2;

    *pp = p;
    *out_n = n64;
    return 1;
}

/* Parse an inline command: a single CRLF-terminated line whose arguments are
 * separated by spaces/tabs. argv slices point into buf.
 * Returns the number of bytes consumed (including the CRLF), 0 if the line is
 * incomplete, or -1 on an empty line, more than RESP_MAX_ARGC arguments, or a
 * line too long to fit within RESP_MAX_LINE. */
static int parse_inline(const uint8_t *buf, int len, resp_cmd_t *out_cmd)
{
    const uint8_t *p = buf;
    const uint8_t *end = buf + (size_t)len;
    const uint8_t *le = NULL;

    int fr = find_crlf(p, end, &le);
    if (fr == CRLF_NEED_MORE) return 0;
    if (fr != CRLF_FOUND) return -1;

    /* Split [p, le) on runs of spaces/tabs. */
    out_cmd->argc = 0;
    const uint8_t *s = p;
    while (s < le) {
        while (s < le && (*s == ' ' || *s == '\t')) s++;
        if (s >= le) break;

        const uint8_t *t = s;
        while (t < le && *t != ' ' && *t != '\t') t++;

        if (out_cmd->argc >= RESP_MAX_ARGC) return -1;
        out_cmd->argv[out_cmd->argc].ptr = s;
        out_cmd->argv[out_cmd->argc].len = (uint32_t)(t - s);
        out_cmd->argc++;
        s = t;
    }

    if (out_cmd->argc == 0) return -1;
    return (int)((le + 2) - buf);
}

/* Parse one command from buf[0..len): a RESP2 multi-bulk array of bulk strings
 * when the first byte is '*', otherwise an inline command.
 * *out_cmd is zeroed first; its argv slices point into buf.
 * Returns bytes consumed (> 0) on success, 0 if more data is needed, and -1 on
 * a protocol error or when buf/out_cmd is NULL or len <= 0.
 * On a 0 return, *out_cmd may be partially filled and must not be used. */
int resp_parse_one_cmd(const uint8_t *buf, int len, resp_cmd_t *out_cmd)
{
    if (!buf || len <= 0 || !out_cmd) return -1;
    memset(out_cmd, 0, sizeof(*out_cmd));

    const uint8_t *p = buf;
    const uint8_t *end = buf + (size_t)len;
    if (need(p, end, 1) < 0) return 0;

    if (*p != '*') {
        return parse_inline(buf, len, out_cmd);
    }

    /* Multi-bulk: "*<argc>\r\n" followed by argc bulk strings. */
    int64_t n64 = 0;
    int r = parse_array_len(&p, end, &n64);
    if (r == 0) return 0;
    if (r < 0) return -1;
    if (n64 <= 0 || n64 > (int64_t)RESP_MAX_ARGC) return -1;
    out_cmd->argc = (uint32_t)n64;

    for (uint32_t i = 0; i < out_cmd->argc; i++) {
        resp_slice_t sl = {0};
        int rr = parse_bulk(&p, end, &sl);
        if (rr == 0) return 0;
        if (rr < 0) return -1;
        out_cmd->argv[i] = sl;
    }
    return (int)(p - buf);
}

/* ----------------- RESP2 builder ----------------- */

/* Append n bytes from src at *pp if they fit before end, advancing *pp.
 * Returns 0 on success, -1 if the output buffer is too small. Zero-length
 * writes skip memcpy, because passing it a NULL source is undefined even
 * when n == 0. */
static int write_bytes(uint8_t **pp, const uint8_t *end, const void *src, size_t n)
{
    uint8_t *p = *pp;
    if ((size_t)(end - p) < n) return -1;
    if (n > 0) {
        memcpy(p, src, n);
        p += n;
    }
    *pp = p;
    return 0;
}

/* Append "\r\n". Returns 0 on success, -1 if it does not fit. */
static int write_crlf(uint8_t **pp, const uint8_t *end)
{
    return write_bytes(pp, end, "\r\n", 2);
}

/* Append the decimal text of a signed 64-bit integer. */
static int write_i64_ascii(uint8_t **pp, const uint8_t *end, int64_t x)
{
    char tmp[64];
    int n = snprintf(tmp, sizeof(tmp), "%lld", (long long)x);
    if (n <= 0) return -1;
    return write_bytes(pp, end, tmp, (size_t)n);
}

/* Append the decimal text of an unsigned 32-bit integer. */
static int write_u32_ascii(uint8_t **pp, const uint8_t *end, uint32_t x)
{
    char tmp[32];
    int n = snprintf(tmp, sizeof(tmp), "%u", (unsigned)x);
    if (n <= 0) return -1;
    return write_bytes(pp, end, tmp, (size_t)n);
}

/* Serialize v in RESP2 wire format into out[0..cap).
 * Returns the number of bytes written, or -1 if v/out is NULL, cap is 0, the
 * type is unknown, or the encoding does not fit in cap bytes.
 * NOTE(review): simple-string and error payloads are written verbatim; a
 * payload containing CR/LF would break framing. */
int resp_build_value(const resp_value_t *v, uint8_t *out, size_t cap)
{
    if (!v || !out || cap == 0) return -1;

    uint8_t *p = out;
    const uint8_t *end = out + cap;

    switch (v->type) {
    case RESP_T_SIMPLE_STR:
        if (write_bytes(&p, end, "+", 1) < 0) return -1;
        if (write_bytes(&p, end, v->bulk.ptr, v->bulk.len) < 0) return -1;
        if (write_crlf(&p, end) < 0) return -1;
        break;
    case RESP_T_ERROR:
        if (write_bytes(&p, end, "-", 1) < 0) return -1;
        if (write_bytes(&p, end, v->bulk.ptr, v->bulk.len) < 0) return -1;
        if (write_crlf(&p, end) < 0) return -1;
        break;
    case RESP_T_INTEGER:
        if (write_bytes(&p, end, ":", 1) < 0) return -1;
        if (write_i64_ascii(&p, end, v->i64) < 0) return -1;
        if (write_crlf(&p, end) < 0) return -1;
        break;
    case RESP_T_NIL:
        if (write_bytes(&p, end, "$-1\r\n", 5) < 0) return -1;
        break;
    case RESP_T_BULK_STR:
        if (write_bytes(&p, end, "$", 1) < 0) return -1;
        if (write_u32_ascii(&p, end, v->bulk.len) < 0) return -1;
        if (write_crlf(&p, end) < 0) return -1;
        if (v->bulk.len > 0 && v->bulk.ptr) {
            if (write_bytes(&p, end, v->bulk.ptr, v->bulk.len) < 0) return -1;
        }
        if (write_crlf(&p, end) < 0) return -1;
        break;
    default:
        return -1;
    }
    return (int)(p - out);
}

/* ----------------- value constructors ----------------- */
/* String-carrying values borrow the caller's pointer; nothing is copied. */

/* Simple string "+s\r\n". A NULL s yields an empty payload. */
resp_value_t resp_simple(const char *s)
{
    resp_value_t v;
    v.type = RESP_T_SIMPLE_STR;
    v.i64 = 0;
    v.bulk.ptr = (const uint8_t *)s;
    v.bulk.len = s ? (uint32_t)strlen(s) : 0;
    return v;
}

/* Error "-s\r\n". A NULL s yields an empty payload. */
resp_value_t resp_error(const char *s)
{
    resp_value_t v;
    v.type = RESP_T_ERROR;
    v.i64 = 0;
    v.bulk.ptr = (const uint8_t *)s;
    v.bulk.len = s ? (uint32_t)strlen(s) : 0;
    return v;
}

/* Integer ":x\r\n". */
resp_value_t resp_int(int64_t x)
{
    resp_value_t v;
    v.type = RESP_T_INTEGER;
    v.i64 = x;
    v.bulk.ptr = NULL;
    v.bulk.len = 0;
    return v;
}

/* Bulk string "$n\r\n<p[0..n)>\r\n". */
resp_value_t resp_bulk(const uint8_t *p, uint32_t n)
{
    resp_value_t v;
    v.type = RESP_T_BULK_STR;
    v.i64 = 0;
    v.bulk.ptr = p;
    v.bulk.len = n;
    return v;
}

/* Nil bulk "$-1\r\n". */
resp_value_t resp_nil(void)
{
    resp_value_t v;
    v.type = RESP_T_NIL;
    v.i64 = 0;
    v.bulk.ptr = NULL;
    v.bulk.len = 0;
    return v;
}

/* ----------------- dispatcher (minimal) ----------------- */
/* Return 0 when cmd is non-NULL and carries exactly n arguments (the command
 * name counts as one), -1 otherwise. */
static int expect_argv(const resp_cmd_t *cmd, uint32_t n)
{
    return (cmd && cmd->argc == n) ? 0 : -1;
}

/* Command names, matched case-insensitively by resp_dispatch(); the index of
 * a matching name is used as its kvs_cmd_t value.
 * NOTE(review): this order must match the kvs_cmd_t enumerators declared in
 * the header — confirm whenever a command is added. */
const char *command[] = {
    "SET",  "GET",  "DEL",  "MOD",  "EXIST",
    "RSET", "RGET", "RDEL", "RMOD", "REXIST",
    "HSET", "HGET", "HDEL", "HMOD", "HEXIST",
    "SAVE", "PSYNC"
};

#if ENABLE_ARRAY || ENABLE_RBTREE || ENABLE_HASH
/* Reply for a backend set result: <0 internal error, 1 key already exists,
 * anything else OK. */
static void reply_set_result(int r, resp_value_t *out)
{
    if (r < 0)
        *out = resp_error("ERR internal error");
    else if (r == 1)
        *out = resp_error("ERR key has exist");
    else
        *out = resp_simple("OK");
}

/* Reply for a backend get result: NULL means the key is absent (nil reply),
 * otherwise a bulk string of vlen bytes borrowed from the backend. */
static void reply_get_result(const char *v, uint32_t vlen, resp_value_t *out)
{
    *out = v ? resp_bulk((const uint8_t *)v, vlen) : resp_nil();
}

/* Reply for a backend del/exist result: <0 internal error, 0 hit -> :1,
 * anything else -> :0. */
static void reply_hit_result(int r, resp_value_t *out)
{
    if (r < 0)
        *out = resp_error("ERR internal error");
    else
        *out = resp_int((r == 0) ? 1 : 0);
}

/* Reply for a backend mod result: <0 internal error, 0 OK, >0 no such key. */
static void reply_mod_result(int r, resp_value_t *out)
{
    if (r < 0)
        *out = resp_error("ERR internal error");
    else if (r == 0)
        *out = resp_simple("OK");
    else
        *out = resp_error("ERR no such key");
}
#endif

/**
 * Execute one parsed command against the enabled storage backends.
 *
 * @param cmd       parsed command; argv[0] is the command name.
 * @param out_value receives the reply. String payloads point at string
 *                  literals or, for GET-style commands, at memory returned by
 *                  the backend (NOTE(review): confirm that memory stays valid
 *                  until the reply is serialized).
 * @return -1 if cmd/out_value is NULL, the command name is unknown, or an
 *         array-backend command (SET/GET/DEL/MOD/EXIST) has the wrong number
 *         of arguments; 0 otherwise. Except for NULL arguments, *out_value
 *         always holds a reply (possibly an error).
 *
 * NOTE(review): wrong arity returns -1 for array commands but 0 for the
 * rbtree/hash/SAVE commands. This is kept as-is because callers may depend on
 * it; reconcile with the intended contract.
 */
int resp_dispatch(const resp_cmd_t *cmd, resp_value_t *out_value)
{
    if (!cmd || !out_value) return -1;

    if (cmd->argc == 0 || cmd->argv[0].ptr == NULL) {
        *out_value = resp_error("ERR empty command");
        return 0;
    }

    const uint8_t *cptr = cmd->argv[0].ptr;
    uint32_t clen = cmd->argv[0].len;
    const size_t ncommands = sizeof(command) / sizeof(command[0]);

    /* Bound by the table size as well, so an enum that grows past the name
     * table cannot read out of bounds. */
    kvs_cmd_t op = KVS_CMD_COUNT;
    for (kvs_cmd_t i = KVS_CMD_START; i < KVS_CMD_COUNT && (size_t)i < ncommands; ++i) {
        if (ascii_casecmp(cptr, clen, command[i]) == 0) {
            op = i;
            break;
        }
    }

    if (op == KVS_CMD_COUNT) {
        *out_value = resp_error("ERR unknown command");
        return -1;
    }

    switch (op) {
#if ENABLE_ARRAY
    case KVS_CMD_SET:
        if (expect_argv(cmd, 3) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'set'");
            return -1;
        }
        reply_set_result(kvs_array_set_bin(&global_array,
                                           cmd->argv[1].ptr, cmd->argv[1].len,
                                           cmd->argv[2].ptr, cmd->argv[2].len),
                         out_value);
        return 0;

    case KVS_CMD_GET: {
        if (expect_argv(cmd, 2) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'get'");
            return -1;
        }
        uint32_t vlen = 0;
        const char *v = kvs_array_get_bin(&global_array, cmd->argv[1].ptr, cmd->argv[1].len, &vlen);
        reply_get_result(v, vlen, out_value);
        return 0;
    }

    case KVS_CMD_DEL:
        if (expect_argv(cmd, 2) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'del'");
            return -1;
        }
        reply_hit_result(kvs_array_del_bin(&global_array, cmd->argv[1].ptr, cmd->argv[1].len),
                         out_value);
        return 0;

    case KVS_CMD_MOD:
        if (expect_argv(cmd, 3) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'mod'");
            return -1;
        }
        reply_mod_result(kvs_array_mod_bin(&global_array,
                                           cmd->argv[1].ptr, cmd->argv[1].len,
                                           cmd->argv[2].ptr, cmd->argv[2].len),
                         out_value);
        return 0;

    case KVS_CMD_EXIST:
        if (expect_argv(cmd, 2) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'exist'");
            return -1;
        }
        reply_hit_result(kvs_array_exist_bin(&global_array, cmd->argv[1].ptr, cmd->argv[1].len),
                         out_value);
        return 0;
#endif

#if ENABLE_RBTREE
    case KVS_CMD_RSET:
        if (expect_argv(cmd, 3) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'rset'");
            return 0;
        }
        /* Backend returns 1 when the key already exists; report it like SET. */
        reply_set_result(kvs_rbtree_set(&global_rbtree,
                                        cmd->argv[1].ptr, cmd->argv[1].len,
                                        cmd->argv[2].ptr, cmd->argv[2].len),
                         out_value);
        return 0;

    case KVS_CMD_RGET: {
        if (expect_argv(cmd, 2) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'rget'");
            return 0;
        }
        uint32_t vlen = 0;
        const char *v = kvs_rbtree_get(&global_rbtree, cmd->argv[1].ptr, cmd->argv[1].len, &vlen);
        reply_get_result(v, vlen, out_value);
        return 0;
    }

    case KVS_CMD_RDEL:
        if (expect_argv(cmd, 2) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'rdel'");
            return 0;
        }
        reply_hit_result(kvs_rbtree_del(&global_rbtree, cmd->argv[1].ptr, cmd->argv[1].len),
                         out_value);
        return 0;

    case KVS_CMD_RMOD:
        if (expect_argv(cmd, 3) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'rmod'");
            return 0;
        }
        reply_mod_result(kvs_rbtree_mod(&global_rbtree,
                                        cmd->argv[1].ptr, cmd->argv[1].len,
                                        cmd->argv[2].ptr, cmd->argv[2].len),
                         out_value);
        return 0;

    case KVS_CMD_REXIST:
        if (expect_argv(cmd, 2) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'rexist'");
            return 0;
        }
        reply_hit_result(kvs_rbtree_exist(&global_rbtree, cmd->argv[1].ptr, cmd->argv[1].len),
                         out_value);
        return 0;
#endif

#if ENABLE_HASH
    case KVS_CMD_HSET:
        if (expect_argv(cmd, 3) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'hset'");
            return 0;
        }
        /* Backend returns 1 when the key already exists; report it like SET. */
        reply_set_result(kvs_hash_set_bin(&global_hash,
                                          cmd->argv[1].ptr, cmd->argv[1].len,
                                          cmd->argv[2].ptr, cmd->argv[2].len),
                         out_value);
        return 0;

    case KVS_CMD_HGET: {
        if (expect_argv(cmd, 2) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'hget'");
            return 0;
        }
        uint32_t vlen = 0;
        const char *v = kvs_hash_get_bin(&global_hash, cmd->argv[1].ptr, cmd->argv[1].len, &vlen);
        reply_get_result(v, vlen, out_value);
        return 0;
    }

    case KVS_CMD_HDEL:
        if (expect_argv(cmd, 2) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'hdel'");
            return 0;
        }
        reply_hit_result(kvs_hash_del_bin(&global_hash, cmd->argv[1].ptr, cmd->argv[1].len),
                         out_value);
        return 0;

    case KVS_CMD_HMOD:
        if (expect_argv(cmd, 3) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'hmod'");
            return 0;
        }
        reply_mod_result(kvs_hash_mod_bin(&global_hash,
                                          cmd->argv[1].ptr, cmd->argv[1].len,
                                          cmd->argv[2].ptr, cmd->argv[2].len),
                         out_value);
        return 0;

    case KVS_CMD_HEXIST:
        if (expect_argv(cmd, 2) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'hexist'");
            return 0;
        }
        reply_hit_result(kvs_hash_exist_bin(&global_hash, cmd->argv[1].ptr, cmd->argv[1].len),
                         out_value);
        return 0;
#endif

    /* ---------------- misc ---------------- */
    case KVS_CMD_SAVE:
        if (expect_argv(cmd, 1) < 0) {
            *out_value = resp_error("ERR wrong number of arguments for 'save'");
            return 0;
        }
        if (kvs_save_to_file() < 0) {
            *out_value = resp_error("ERR save failed");
            return 0;
        }
        *out_value = resp_simple("OK");
        return 0;

    case KVS_CMD_PSYNC:
        *out_value = resp_simple("OK");
        return 0;

    default:
        /* Recognized name whose backend is compiled out. */
        break;
    }

    *out_value = resp_error("ERR unknown command");
    return 0;
}