diff --git a/ai-chat-service/docker.config.yaml b/ai-chat-service/docker.config.yaml index 2b35dc6..e819b02 100644 --- a/ai-chat-service/docker.config.yaml +++ b/ai-chat-service/docker.config.yaml @@ -17,9 +17,9 @@ chat: bot_desc: "你是一个AI助手,我需要你模拟一名资深的软件工程师来回答我的问题" min_response_tokens: 600 redis: - host: "host.docker.internal" - port: 8888 - pwd: "123456" + host: "redis" + port: 6379 + pwd: "" dependOn: sensitive: address: "sensitive-filter:50053" @@ -36,7 +36,7 @@ embedding: model: "embedding-2" timeout: 10 faiss: - base_url: "http://host.docker.internal:8451" + base_url: "http://faiss:8000" search_k: 1 similarity_threshold: 0.9 timeout: 10 diff --git a/ai-chat-stack/compose.yaml b/ai-chat-stack/compose.yaml index de81ab2..633431f 100644 --- a/ai-chat-stack/compose.yaml +++ b/ai-chat-stack/compose.yaml @@ -1,4 +1,18 @@ services: + faiss: + build: + context: ../faiss + container_name: faiss-service + ports: + - "8451:8000" + volumes: + - ../faiss/indexes:/app/indexes # 持久化索引目录 + - ../faiss/.env:/app/.env # 可选:挂载配置 + restart: unless-stopped + environment: + - FAISS_DIM=1024 + - APP_PORT=8000 + redis: image: redis:7-alpine container_name: ai-chat-redis @@ -60,12 +74,10 @@ services: - .env volumes: - ./configs/ai-chat-service.yaml:/app/config.yaml:ro - extra_hosts: - - "host.docker.internal:host-gateway" ports: - "50055:50055" depends_on: - - ai-chat-redis + - faiss - tokenizer - sensitive-filter - keywords-filter diff --git a/ai-chat-stack/configs/ai-chat-service.yaml b/ai-chat-stack/configs/ai-chat-service.yaml index 9da615b..2361c76 100644 --- a/ai-chat-stack/configs/ai-chat-service.yaml +++ b/ai-chat-stack/configs/ai-chat-service.yaml @@ -36,7 +36,7 @@ embedding: model: "embedding-2" timeout: 10 faiss: - base_url: "http://host.docker.internal:8451" + base_url: "http://faiss:8000" search_k: 1 similarity_threshold: 0.9 timeout: 10 diff --git a/faiss/.gitignore b/faiss/.gitignore index 1836f4c..9db062a 100644 --- a/faiss/.gitignore +++ b/faiss/.gitignore @@ -1,3 +1,6 @@ -faiss_index.bin +indexes/* +!indexes/.gitkeep +!indexes/global/.gitkeep +!indexes/conversations/.gitkeep .vscode -__pycache__ \ No newline at end of file +__pycache__ diff --git a/faiss/api.py b/faiss/api.py index e99a3a5..5b5b505 100644 --- a/faiss/api.py +++ b/faiss/api.py @@ -1,41 +1,81 @@ -# api.py -from fastapi import FastAPI, Depends, HTTPException -from models import EmbeddingInput, SearchInput -from faiss_manager import faiss_manager +from fastapi import FastAPI, HTTPException + from config import get_settings +from faiss_manager import faiss_manager +from models import EmbeddingInput, IndexDeleteInput, IndexInsertInput, IndexSearchInput, SearchInput settings = get_settings() app = FastAPI( title="FAISS 服务", - description="向量插入 + 相似搜索 + 持久化", - version="1.0.0" + description="向量插入 + 相似搜索 + 多索引管理", + version="1.0.0", ) + def log_business(message: str): if settings.ENABLE_REQUEST_LOGS: print(message, flush=True) + @app.post("/insert") async def insert(data: EmbeddingInput): try: vector_id = faiss_manager.insert(data.embedding) - log_business(f"[faiss] insert id={vector_id}") + log_business(f"[faiss] global insert id={vector_id}") return {"id": vector_id} - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) @app.post("/search") async def search(data: SearchInput): try: result = faiss_manager.search(data.embedding, data.k) - log_business( - f"[faiss] search ids={result['ids']} similarity_scores={result['similarity_scores']}", - ) + log_business(f"[faiss] global search ids={result['ids']} similarity_scores={result['similarity_scores']}") return result - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@app.post("/create") +async def create(): + try: + idx = faiss_manager.create() + log_business(f"[faiss] create idx={idx}") + return {"idx": idx} + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@app.post("/insert_idx") +async def insert_idx(data: IndexInsertInput): + try: + vector_id = faiss_manager.insert_idx(data.idx, data.id, data.embedding) + log_business(f"[faiss] insert_idx idx={data.idx} id={vector_id}") + return {"id": vector_id, "idx": data.idx} + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@app.post("/search_idx") +async def search_idx(data: IndexSearchInput): + try: + result = faiss_manager.search_idx(data.idx, data.embedding, data.k) + log_business(f"[faiss] search_idx idx={data.idx} ids={result['ids']} similarity_scores={result['similarity_scores']}") + return result + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@app.post("/del") +async def delete_idx(data: IndexDeleteInput): + try: + faiss_manager.delete(data.idx) + log_business(f"[faiss] delete idx={data.idx}") + return {"status": "success", "idx": data.idx} + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) @app.post("/persist") diff --git a/faiss/config.py b/faiss/config.py index b0c1fb0..e314f69 100644 --- a/faiss/config.py +++ b/faiss/config.py @@ -5,6 +5,10 @@ from functools import lru_cache class Settings(BaseSettings): # FAISS 配置(已优化为你的 LLM 相似问题缓存场景) FAISS_DIM: int = 1024 # 根据你的 embedding 模型修改(e.g. bge-large=1024, text-embedding-3-large=3072) + FAISS_INDEX_DIR: str = "indexes" + FAISS_GLOBAL_DIR: str = "global" + FAISS_CONVERSATION_DIR: str = "conversations" + FAISS_GLOBAL_INDEX_NAME: str = "global_qa.index" FAISS_INDEX_PATH: str = "faiss_index.bin" FAISS_INDEX_TYPE: str = "HNSW" # 默认改为 HNSW(最推荐) diff --git a/faiss/docker-compose.yml b/faiss/docker-compose.yml deleted file mode 100644 index 2186738..0000000 --- a/faiss/docker-compose.yml +++ /dev/null @@ -1,13 +0,0 @@ -services: - faiss: - build: . - container_name: faiss-service - ports: - - "8451:8000" - volumes: - - ./faiss_index.bin:/app/faiss_index.bin # 持久化索引文件 - - ./.env:/app/.env # 可选:挂载配置 - restart: unless-stopped - environment: - - FAISS_DIM=1024 - - APP_PORT=8000 \ No newline at end of file diff --git a/faiss/faiss_manager.py b/faiss/faiss_manager.py index ce9a38a..28dc2cc 100644 --- a/faiss/faiss_manager.py +++ b/faiss/faiss_manager.py @@ -1,87 +1,199 @@ -# faiss_manager.py import os -import numpy as np +import uuid + import faiss +import numpy as np + from config import get_settings settings = get_settings() + class FaissManager: def __init__(self): self.dim = settings.FAISS_DIM - self.index_path = settings.FAISS_INDEX_PATH + self.index_dir = settings.FAISS_INDEX_DIR + self.global_dir = os.path.join(self.index_dir, settings.FAISS_GLOBAL_DIR) + self.conversation_dir = os.path.join(self.index_dir, settings.FAISS_CONVERSATION_DIR) + self.global_index_path = os.path.join(self.global_dir, settings.FAISS_GLOBAL_INDEX_NAME) + self.legacy_index_path = settings.FAISS_INDEX_PATH self.use_cosine = settings.USE_COSINE_SIMILARITY - self.index = None - self._load_or_create_index() + self.global_index = None + self.conversation_indexes = {} + self._load_indexes() - def _load_or_create_index(self): - if os.path.exists(self.index_path): - self.index = faiss.read_index(self.index_path) - print(f"✅ 加载已有索引:{self.index.ntotal} 个向量,维度={self.index.d}") - return + def _load_indexes(self): + os.makedirs(self.global_dir, exist_ok=True) + os.makedirs(self.conversation_dir, exist_ok=True) + + if os.path.isfile(self.global_index_path): + self.global_index = self._load_index_from_path(self.global_index_path) + print(f"✅ 加载全局索引:{self.global_index.ntotal} 个向量") + elif os.path.isfile(self.legacy_index_path): + legacy_index = faiss.read_index(self.legacy_index_path) + self.global_index = self._ensure_id_map(legacy_index) + self._persist_index(self.global_index, self.global_index_path) + print(f"✅ 从旧索引迁移全局索引:{self.global_index.ntotal} 个向量") + else: + self.global_index = self._create_index("global") + self._persist_index(self.global_index, self.global_index_path) + print("✅ 创建全局索引") + + for file_name in sorted(os.listdir(self.conversation_dir)): + if not file_name.endswith(".index"): + continue + idx = file_name[: -len(".index")] + index_path = self._conversation_index_path(idx) + self.conversation_indexes[idx] = self._load_index_from_path(index_path) + if self.conversation_indexes: + print(f"✅ 加载对话索引:{len(self.conversation_indexes)} 个") + + def _create_base_index(self, kind: str): + if kind == "conversation": + if self.use_cosine: + return faiss.IndexFlatIP(self.dim) + return faiss.IndexFlatL2(self.dim) - # 创建新索引 if settings.FAISS_INDEX_TYPE == "HNSW": if self.use_cosine: - self.index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M, faiss.METRIC_INNER_PRODUCT) - print("✅ 创建 HNSWIP 索引(余弦相似度)") + index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M, faiss.METRIC_INNER_PRODUCT) else: - self.index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M) - print("✅ 创建 HNSWFlat 索引(L2 距离)") - - # 设置 HNSW 参数 - self.index.hnsw.efConstruction = settings.HNSW_EF_CONSTRUCTION - self.index.hnsw.efSearch = settings.HNSW_EF_SEARCH - print(f" HNSW 参数: M={settings.HNSW_M}, efConstruction={settings.HNSW_EF_CONSTRUCTION}, efSearch={settings.HNSW_EF_SEARCH}") + index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M) + index.hnsw.efConstruction = settings.HNSW_EF_CONSTRUCTION + index.hnsw.efSearch = settings.HNSW_EF_SEARCH + return index + if settings.FAISS_INDEX_TYPE == "FlatIP" and self.use_cosine: + return faiss.IndexFlatIP(self.dim) + return faiss.IndexFlatL2(self.dim) - elif settings.FAISS_INDEX_TYPE == "FlatIP" and self.use_cosine: - self.index = faiss.IndexFlatIP(self.dim) - print("✅ 创建 FlatIP 索引(精确余弦)") - else: - # 默认精确 L2(兼容旧配置) - self.index = faiss.IndexFlatL2(self.dim) - print("✅ 创建 FlatL2 索引(精确欧式)") + def _create_index(self, kind: str): + return faiss.IndexIDMap2(self._create_base_index(kind)) def _normalize(self, embedding: list[float]) -> np.ndarray: - """L2 归一化(余弦相似度必需)""" vec = np.array(embedding, dtype=np.float32) norm = np.linalg.norm(vec) return vec / norm if norm > 0 else vec - def insert(self, embedding: list[float]) -> int: - """插入向量,返回 ID""" + def _prepare_embedding(self, embedding: list[float]) -> np.ndarray: if len(embedding) != self.dim: raise ValueError(f"Embedding 维度错误,应为 {self.dim}") - vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32) - vec = vec.reshape(1, -1) - - idx = self.index.ntotal - self.index.add(vec) + return vec.reshape(1, -1) + + def _conversation_index_path(self, idx: str) -> str: + return os.path.join(self.conversation_dir, f"{idx}.index") + + def _persist_index(self, index, path: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + faiss.write_index(index, path) + + def _load_index_from_path(self, path: str): + index = faiss.read_index(path) + return self._ensure_id_map(index) + + def _ensure_id_map(self, index): + if hasattr(index, "id_map"): + return index + + converted = faiss.IndexIDMap2(self._create_base_index("global")) + if index.ntotal == 0: + return converted + + vectors = np.empty((index.ntotal, index.d), dtype=np.float32) + index.reconstruct_n(0, index.ntotal, vectors) + ids = np.arange(index.ntotal, dtype=np.int64) + converted.add_with_ids(vectors, ids) + return converted + + def _index_ids(self, index) -> np.ndarray: + if not hasattr(index, "id_map"): + return np.arange(index.ntotal, dtype=np.int64) + return faiss.vector_to_array(index.id_map) + + def _next_global_id(self) -> int: + ids = self._index_ids(self.global_index) + if ids.size == 0: + return 0 + return int(ids.max()) + 1 + + def _get_conversation_index(self, idx: str): + if idx in self.conversation_indexes: + return self.conversation_indexes[idx] + + index_path = self._conversation_index_path(idx) + if not os.path.isfile(index_path): + raise ValueError(f"索引不存在: {idx}") + index = self._load_index_from_path(index_path) + self.conversation_indexes[idx] = index + return index + + def create(self) -> str: + idx = f"conv_{uuid.uuid4().hex}" + index = self._create_index("conversation") + self.conversation_indexes[idx] = index + self._persist_index(index, self._conversation_index_path(idx)) return idx - def search(self, embedding: list[float], k: int = 5): - """搜索相似向量(返回 id + 距离)""" - if len(embedding) != self.dim: - raise ValueError(f"Embedding 维度错误,应为 {self.dim}") + def insert(self, embedding: list[float]) -> int: + vector_id = self._next_global_id() + self.insert_global_with_id(vector_id, embedding) + return vector_id - vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32) - vec = vec.reshape(1, -1) + def insert_global_with_id(self, vector_id: int, embedding: list[float]) -> int: + vec = self._prepare_embedding(embedding) + ids = np.array([vector_id], dtype=np.int64) + self.global_index.add_with_ids(vec, ids) + self._persist_index(self.global_index, self.global_index_path) + return vector_id - distances, indices = self.index.search(vec, k) - + def insert_idx(self, idx: str, vector_id: int, embedding: list[float]) -> int: + index = self._get_conversation_index(idx) + vec = self._prepare_embedding(embedding) + ids = np.array([vector_id], dtype=np.int64) + index.add_with_ids(vec, ids) + self._persist_index(index, self._conversation_index_path(idx)) + return vector_id + + def _search(self, index, embedding: list[float], k: int = 5): + vec = self._prepare_embedding(embedding) + distances, indices = index.search(vec, k) + ids = indices[0].tolist() + raw_distances = distances[0].tolist() + normalized_distances = [] + similarity_scores = [] + for vector_id, distance in zip(ids, raw_distances): + if vector_id < 0: + normalized_distances.append(0.0) + similarity_scores.append(0.0) + continue + normalized_distances.append(distance) + similarity_scores.append(1 - distance if not self.use_cosine else distance) return { - "ids": indices[0].tolist(), - "distances": distances[0].tolist(), # 余弦时:值越大越相似(1.0=完全相同) - "similarity_scores": [1 - d for d in distances[0].tolist()] if not self.use_cosine else distances[0].tolist() + "ids": ids, + "distances": normalized_distances, + "similarity_scores": similarity_scores, } + def search(self, embedding: list[float], k: int = 5): + return self._search(self.global_index, embedding, k) + + def search_idx(self, idx: str, embedding: list[float], k: int = 5): + index = self._get_conversation_index(idx) + return self._search(index, embedding, k) + + def delete(self, idx: str) -> bool: + self._get_conversation_index(idx) + self.conversation_indexes.pop(idx, None) + index_path = self._conversation_index_path(idx) + if os.path.exists(index_path): + os.remove(index_path) + return True + def persist(self): - """保存索引""" - faiss.write_index(self.index, self.index_path) - print(f"💾 索引已保存 → {self.index_path}(共 {self.index.ntotal} 个向量)") + self._persist_index(self.global_index, self.global_index_path) + for idx, index in self.conversation_indexes.items(): + self._persist_index(index, self._conversation_index_path(idx)) return True -# 单例 faiss_manager = FaissManager() diff --git a/faiss/indexes/.gitkeep b/faiss/indexes/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/faiss/indexes/.gitkeep @@ -0,0 +1 @@ + diff --git a/faiss/models.py b/faiss/models.py index 3788506..e088879 100644 --- a/faiss/models.py +++ b/faiss/models.py @@ -1,10 +1,28 @@ -# models.py -from pydantic import BaseModel from typing import List +from pydantic import BaseModel + + class EmbeddingInput(BaseModel): embedding: List[float] + class SearchInput(BaseModel): embedding: List[float] - k: int = 5 \ No newline at end of file + k: int = 5 + + +class IndexInsertInput(BaseModel): + idx: str + id: int + embedding: List[float] + + +class IndexSearchInput(BaseModel): + idx: str + embedding: List[float] + k: int = 5 + + +class IndexDeleteInput(BaseModel): + idx: str