import contextlib
import os
import uuid

import faiss
import numpy as np

from config import get_settings

settings = get_settings()


class FaissManager:
    """Manage one global FAISS index plus any number of per-conversation indexes.

    Every index is wrapped in ``faiss.IndexIDMap2`` so vectors are stored under
    caller-chosen int64 ids. On disk, the global index lives at
    ``FAISS_INDEX_DIR/FAISS_GLOBAL_DIR/FAISS_GLOBAL_INDEX_NAME`` and each
    conversation index at ``FAISS_INDEX_DIR/FAISS_CONVERSATION_DIR/<idx>.index``.
    Every mutating call (create/insert/delete) is written back to disk
    immediately.
    """

    _INDEX_SUFFIX = ".index"

    def __init__(self):
        self.dim = settings.FAISS_DIM
        self.index_dir = settings.FAISS_INDEX_DIR
        self.global_dir = os.path.join(self.index_dir, settings.FAISS_GLOBAL_DIR)
        self.conversation_dir = os.path.join(self.index_dir, settings.FAISS_CONVERSATION_DIR)
        self.global_index_path = os.path.join(self.global_dir, settings.FAISS_GLOBAL_INDEX_NAME)
        # Single-file index location used before the global/conversation split.
        self.legacy_index_path = settings.FAISS_INDEX_PATH
        self.use_cosine = settings.USE_COSINE_SIMILARITY
        self.global_index = None
        # In-memory cache: conversation id -> loaded FAISS index.
        self.conversation_indexes = {}
        self._load_indexes()

    def _load_indexes(self):
        """Load (or create) the global index, then eagerly load every conversation index.

        The global index is read from ``global_index_path`` when it exists;
        otherwise a legacy index at ``legacy_index_path`` is migrated into the
        new location; otherwise a fresh empty index is created and saved.
        """
        os.makedirs(self.global_dir, exist_ok=True)
        os.makedirs(self.conversation_dir, exist_ok=True)

        if os.path.isfile(self.global_index_path):
            self.global_index = self._load_index_from_path(self.global_index_path)
            print(f"✅ 加载全局索引:{self.global_index.ntotal} 个向量")
        elif os.path.isfile(self.legacy_index_path):
            legacy_index = faiss.read_index(self.legacy_index_path)
            self.global_index = self._ensure_id_map(legacy_index)
            self._persist_index(self.global_index, self.global_index_path)
            print(f"✅ 从旧索引迁移全局索引:{self.global_index.ntotal} 个向量")
        else:
            self.global_index = self._create_index("global")
            self._persist_index(self.global_index, self.global_index_path)
            print("✅ 创建全局索引")

        suffix = self._INDEX_SUFFIX
        for file_name in sorted(os.listdir(self.conversation_dir)):
            # Leftover "<name>.index.tmp" files from an interrupted write are skipped here.
            if not file_name.endswith(suffix):
                continue
            idx = file_name[: -len(suffix)]
            # A bare ".index" file yields an empty id the public API can never address.
            if not self._is_valid_idx(idx):
                continue
            index_path = self._conversation_index_path(idx)
            self.conversation_indexes[idx] = self._load_index_from_path(index_path, kind="conversation")
        if self.conversation_indexes:
            print(f"✅ 加载对话索引:{len(self.conversation_indexes)} 个")

    def _create_base_index(self, kind: str):
        """Build the raw (un-id-mapped) FAISS index for ``kind``.

        Conversation indexes are always flat: inner product when cosine
        similarity is enabled, L2 otherwise. The global index honours
        ``FAISS_INDEX_TYPE``: "HNSW" builds an HNSW graph (inner product when
        cosine is enabled). "FlatIP" plus cosine builds a flat inner-product
        index. Every other combination falls back to flat L2.
        """
        if kind == "conversation":
            if self.use_cosine:
                return faiss.IndexFlatIP(self.dim)
            return faiss.IndexFlatL2(self.dim)

        if settings.FAISS_INDEX_TYPE == "HNSW":
            if self.use_cosine:
                index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M, faiss.METRIC_INNER_PRODUCT)
            else:
                index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M)
            index.hnsw.efConstruction = settings.HNSW_EF_CONSTRUCTION
            index.hnsw.efSearch = settings.HNSW_EF_SEARCH
            return index

        if settings.FAISS_INDEX_TYPE == "FlatIP" and self.use_cosine:
            return faiss.IndexFlatIP(self.dim)
        return faiss.IndexFlatL2(self.dim)

    def _create_index(self, kind: str):
        """Return an empty ``IndexIDMap2`` wrapping the base index for ``kind``."""
        return faiss.IndexIDMap2(self._create_base_index(kind))

    def _normalize(self, embedding: list[float]) -> np.ndarray:
        """Return ``embedding`` as a float32 unit vector (zero vectors are returned unchanged)."""
        vec = np.array(embedding, dtype=np.float32)
        norm = np.linalg.norm(vec)
        return vec / norm if norm > 0 else vec

    def _prepare_embedding(self, embedding: list[float]) -> np.ndarray:
        """Validate the dimension and return a (1, dim) float32 query/insert matrix.

        Raises:
            ValueError: if ``len(embedding) != self.dim``.
        """
        if len(embedding) != self.dim:
            raise ValueError(f"Embedding 维度错误,应为 {self.dim}")
        vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32)
        return vec.reshape(1, -1)

    @staticmethod
    def _is_valid_idx(idx) -> bool:
        """Return True if ``idx`` is usable as a bare file name inside ``conversation_dir``.

        Rejects empty names, "." / "..", and anything containing a path
        separator or NUL. Without this check, an id such as "../global/x" would
        escape the conversation directory.
        """
        name = str(idx)
        if name in ("", ".", ".."):
            return False
        forbidden = (os.sep, os.altsep, "/", "\\", "\0")
        return not any(sep in name for sep in forbidden if sep)

    def _conversation_index_path(self, idx: str) -> str:
        """Return the on-disk path of conversation index ``idx``.

        Raises:
            ValueError: if ``idx`` is not a plain file name (path traversal guard).
        """
        if not self._is_valid_idx(idx):
            raise ValueError(f"非法的索引 ID: {idx}")
        return os.path.join(self.conversation_dir, f"{idx}{self._INDEX_SUFFIX}")

    def _persist_index(self, index, path: str):
        """Write ``index`` to ``path`` atomically.

        The index is first serialised to ``<path>.tmp`` and then moved over the
        target with ``os.replace``. A crash mid-write therefore cannot leave a
        truncated index at ``path``.
        """
        directory = os.path.dirname(path)
        # os.makedirs("") raises, so only create a directory when the path has one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{path}.tmp"
        try:
            faiss.write_index(index, tmp_path)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; the original error is re-raised.
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
            raise

    def _load_index_from_path(self, path: str, kind: str = "global"):
        """Read an index from disk, converting it to ``IndexIDMap2`` if it lacks an id map."""
        index = faiss.read_index(path)
        return self._ensure_id_map(index, kind=kind)

    def _ensure_id_map(self, index, kind: str = "global"):
        """Return ``index`` unchanged if it already has an id map, else an id-mapped copy.

        Vectors of a plain index are reconstructed and re-added to a fresh
        ``IndexIDMap2`` of the given ``kind``. They get ids 0..ntotal-1, which
        matches the implicit positional ids FAISS used before conversion.
        NOTE(review): the fresh base index uses the *current* settings' metric,
        which may differ from the legacy index's metric — confirm on migration.
        """
        if hasattr(index, "id_map"):
            return index
        converted = faiss.IndexIDMap2(self._create_base_index(kind))
        if index.ntotal == 0:
            return converted
        vectors = np.empty((index.ntotal, index.d), dtype=np.float32)
        index.reconstruct_n(0, index.ntotal, vectors)
        ids = np.arange(index.ntotal, dtype=np.int64)
        converted.add_with_ids(vectors, ids)
        return converted

    def _index_ids(self, index) -> np.ndarray:
        """Return all vector ids stored in ``index`` (positional ids if it has no id map)."""
        if not hasattr(index, "id_map"):
            return np.arange(index.ntotal, dtype=np.int64)
        return faiss.vector_to_array(index.id_map)

    def _next_global_id(self) -> int:
        """Return one past the largest id in the global index (0 when it is empty)."""
        ids = self._index_ids(self.global_index)
        if ids.size == 0:
            return 0
        return int(ids.max()) + 1

    def _get_conversation_index(self, idx: str):
        """Return conversation index ``idx``, loading it from disk on a cache miss.

        Raises:
            ValueError: if ``idx`` is not a valid id or no index file exists for it.
        """
        if idx in self.conversation_indexes:
            return self.conversation_indexes[idx]
        index_path = self._conversation_index_path(idx)
        if not os.path.isfile(index_path):
            raise ValueError(f"索引不存在: {idx}")
        index = self._load_index_from_path(index_path, kind="conversation")
        self.conversation_indexes[idx] = index
        return index

    def create(self) -> str:
        """Create, cache and persist an empty conversation index; return its new id."""
        idx = f"conv_{uuid.uuid4().hex}"
        index = self._create_index("conversation")
        self.conversation_indexes[idx] = index
        self._persist_index(index, self._conversation_index_path(idx))
        return idx

    def insert(self, embedding: list[float]) -> int:
        """Add ``embedding`` to the global index under the next free id and return that id."""
        vector_id = self._next_global_id()
        self.insert_global_with_id(vector_id, embedding)
        return vector_id

    def insert_global_with_id(self, vector_id: int, embedding: list[float]) -> int:
        """Add ``embedding`` to the global index under ``vector_id``, persist, and return the id."""
        vec = self._prepare_embedding(embedding)
        ids = np.array([vector_id], dtype=np.int64)
        self.global_index.add_with_ids(vec, ids)
        self._persist_index(self.global_index, self.global_index_path)
        return vector_id

    def insert_idx(self, idx: str, vector_id: int, embedding: list[float]) -> int:
        """Add ``embedding`` under ``vector_id`` to conversation index ``idx``, persist, return the id."""
        index = self._get_conversation_index(idx)
        vec = self._prepare_embedding(embedding)
        ids = np.array([vector_id], dtype=np.int64)
        index.add_with_ids(vec, ids)
        self._persist_index(index, self._conversation_index_path(idx))
        return vector_id

    def _similarity(self, distance: float, is_inner_product: bool) -> float:
        """Convert a raw FAISS score to a similarity score (higher = more similar).

        Inner-product indexes already return a similarity. L2 indexes in FAISS
        return *squared* distances. For unit vectors (cosine mode),
        ``||a - b||^2 = 2 - 2cos``, so ``1 - d/2`` is the exact cosine.
        Otherwise the historical ``1 - d`` convention is kept.
        """
        if is_inner_product:
            return distance
        if self.use_cosine:
            return 1 - distance / 2
        return 1 - distance

    def _search(self, index, embedding: list[float], k: int = 5):
        """Run a k-NN query against ``index``.

        Returns:
            dict with parallel lists "ids", "distances" (raw FAISS scores) and
            "similarity_scores". Slots FAISS could not fill (id -1, when the
            index holds fewer than ``k`` vectors) report 0.0 for both scores.
        """
        vec = self._prepare_embedding(embedding)
        distances, indices = index.search(vec, k)
        ids = indices[0].tolist()
        raw_distances = distances[0].tolist()
        # Derive scoring from the index's actual metric rather than from settings.
        # A cosine-configured global index can still be flat L2 (see _create_base_index).
        is_inner_product = index.metric_type == faiss.METRIC_INNER_PRODUCT

        normalized_distances = []
        similarity_scores = []
        for vector_id, distance in zip(ids, raw_distances):
            if vector_id < 0:
                normalized_distances.append(0.0)
                similarity_scores.append(0.0)
                continue
            normalized_distances.append(distance)
            similarity_scores.append(self._similarity(distance, is_inner_product))
        return {
            "ids": ids,
            "distances": normalized_distances,
            "similarity_scores": similarity_scores,
        }

    def search(self, embedding: list[float], k: int = 5):
        """Query the global index; see ``_search`` for the result format."""
        return self._search(self.global_index, embedding, k)

    def search_idx(self, idx: str, embedding: list[float], k: int = 5):
        """Query conversation index ``idx``; see ``_search`` for the result format."""
        index = self._get_conversation_index(idx)
        return self._search(index, embedding, k)

    def delete(self, idx: str) -> bool:
        """Drop conversation index ``idx`` from memory and disk; return True.

        The global index is not touched.

        Raises:
            ValueError: if ``idx`` is invalid or does not exist.
        """
        # Validates that the id exists (and is well-formed) before removing anything.
        self._get_conversation_index(idx)
        self.conversation_indexes.pop(idx, None)
        index_path = self._conversation_index_path(idx)
        if os.path.exists(index_path):
            os.remove(index_path)
        return True

    def persist(self):
        """Write the global index and every cached conversation index to disk; return True."""
        self._persist_index(self.global_index, self.global_index_path)
        for idx, index in self.conversation_indexes.items():
            self._persist_index(index, self._conversation_index_path(idx))
        return True


faiss_manager = FaissManager()