Files
mchat/faiss/faiss_manager.py
2026-04-10 12:47:39 +00:00

200 lines
7.8 KiB
Python

import os
import uuid
import faiss
import numpy as np
from config import get_settings
# Application settings, read once when this module is imported; FaissManager
# reads its FAISS_* / HNSW_* / USE_COSINE_SIMILARITY options from this object.
settings = get_settings()
class FaissManager:
    """Manage one global FAISS index plus one FAISS index per conversation.

    Every index is wrapped in ``faiss.IndexIDMap2`` so vectors are addressed by
    caller-chosen integer ids, and every mutation is written back to disk:

    * global index: ``<FAISS_INDEX_DIR>/<FAISS_GLOBAL_DIR>/<FAISS_GLOBAL_INDEX_NAME>``
    * conversation index: ``<FAISS_INDEX_DIR>/<FAISS_CONVERSATION_DIR>/<idx>.index``

    When ``USE_COSINE_SIMILARITY`` is enabled, embeddings are L2-normalised
    before being added or searched, so inner-product indexes score by cosine.
    """

    def __init__(self):
        self.dim = settings.FAISS_DIM
        self.index_dir = settings.FAISS_INDEX_DIR
        self.global_dir = os.path.join(self.index_dir, settings.FAISS_GLOBAL_DIR)
        self.conversation_dir = os.path.join(self.index_dir, settings.FAISS_CONVERSATION_DIR)
        self.global_index_path = os.path.join(self.global_dir, settings.FAISS_GLOBAL_INDEX_NAME)
        # Pre-migration single-file index; only read to seed a missing global index.
        self.legacy_index_path = settings.FAISS_INDEX_PATH
        self.use_cosine = settings.USE_COSINE_SIMILARITY
        self.global_index = None
        # Conversation id -> loaded index (filled at startup and lazily afterwards).
        self.conversation_indexes = {}
        self._load_indexes()

    # ------------------------------------------------------------------ loading

    def _load_indexes(self):
        """Ensure the index directories exist, then load global and conversation indexes."""
        os.makedirs(self.global_dir, exist_ok=True)
        os.makedirs(self.conversation_dir, exist_ok=True)
        self._load_global_index()
        self._load_conversation_indexes()

    def _load_global_index(self):
        """Load the global index, migrating the legacy file or creating an empty one if absent."""
        if os.path.isfile(self.global_index_path):
            self.global_index = self._load_index_from_path(self.global_index_path)
            print(f"✅ 加载全局索引:{self.global_index.ntotal} 个向量")
        elif self.legacy_index_path and os.path.isfile(self.legacy_index_path):
            legacy_index = faiss.read_index(self.legacy_index_path)
            self.global_index = self._ensure_id_map(legacy_index, "global")
            self._persist_index(self.global_index, self.global_index_path)
            print(f"✅ 从旧索引迁移全局索引:{self.global_index.ntotal} 个向量")
        else:
            self.global_index = self._create_index("global")
            self._persist_index(self.global_index, self.global_index_path)
            print("✅ 创建全局索引")

    def _load_conversation_indexes(self):
        """Eagerly load every ``<idx>.index`` file found in the conversation directory."""
        for file_name in sorted(os.listdir(self.conversation_dir)):
            if not file_name.endswith(".index"):
                continue
            idx = file_name[: -len(".index")]
            if not self._is_valid_conversation_id(idx):
                # e.g. a stray file literally named ".index": no caller can address it.
                continue
            index_path = self._conversation_index_path(idx)
            self.conversation_indexes[idx] = self._load_index_from_path(index_path, "conversation")
        if self.conversation_indexes:
            print(f"✅ 加载对话索引:{len(self.conversation_indexes)}")

    # ----------------------------------------------------------------- creation

    def _create_base_index(self, kind: str):
        """Build an empty, un-wrapped FAISS index of dimension ``self.dim``.

        Conversation indexes are always exact flat indexes (IP when cosine is
        enabled, else L2) regardless of ``FAISS_INDEX_TYPE``. The global index
        honours ``FAISS_INDEX_TYPE``: "HNSW" -> HNSW graph; "FlatIP" -> inner
        product but only when cosine is enabled; anything else -> flat L2.
        """
        if kind == "conversation":
            if self.use_cosine:
                return faiss.IndexFlatIP(self.dim)
            return faiss.IndexFlatL2(self.dim)
        if settings.FAISS_INDEX_TYPE == "HNSW":
            if self.use_cosine:
                index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M, faiss.METRIC_INNER_PRODUCT)
            else:
                index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M)
            index.hnsw.efConstruction = settings.HNSW_EF_CONSTRUCTION
            index.hnsw.efSearch = settings.HNSW_EF_SEARCH
            return index
        if settings.FAISS_INDEX_TYPE == "FlatIP" and self.use_cosine:
            return faiss.IndexFlatIP(self.dim)
        return faiss.IndexFlatL2(self.dim)

    def _create_index(self, kind: str):
        """Build an empty base index of *kind* wrapped in ``IndexIDMap2`` (explicit ids)."""
        return faiss.IndexIDMap2(self._create_base_index(kind))

    # ---------------------------------------------------------------- embedding

    def _normalize(self, embedding: list[float]) -> np.ndarray:
        """Return *embedding* as float32 scaled to unit L2 norm (zero vectors unchanged)."""
        vec = np.array(embedding, dtype=np.float32)
        norm = np.linalg.norm(vec)
        return vec / norm if norm > 0 else vec

    def _prepare_embedding(self, embedding: list[float]) -> np.ndarray:
        """Validate *embedding* length and return it as a ``(1, dim)`` float32 row.

        Raises:
            ValueError: if ``len(embedding) != self.dim``.
        """
        if len(embedding) != self.dim:
            raise ValueError(f"Embedding 维度错误,应为 {self.dim}")
        vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32)
        return vec.reshape(1, -1)

    # -------------------------------------------------------------------- paths

    @staticmethod
    def _is_valid_conversation_id(idx) -> bool:
        """Return True if *idx* is a plain file-name component (cannot escape the directory)."""
        name = f"{idx}"
        if name in ("", ".", "..") or "\x00" in name:
            return False
        if os.altsep and os.altsep in name:
            return False
        # basename() strips any directory part, so equality means "no separators".
        return os.path.basename(name) == name

    def _conversation_index_path(self, idx: str) -> str:
        """Return the on-disk path of conversation index *idx*.

        Raises:
            ValueError: if *idx* contains a path separator, is empty, or is
                ``.``/``..`` — such ids could otherwise read or delete files
                outside ``conversation_dir``.
        """
        if not self._is_valid_conversation_id(idx):
            raise ValueError(f"索引不存在: {idx}")
        return os.path.join(self.conversation_dir, f"{idx}.index")

    # -------------------------------------------------------------- persistence

    def _persist_index(self, index, path: str):
        """Write *index* to *path* atomically.

        The index is written to a uniquely named sibling temp file and then
        moved over *path* with ``os.replace``, so a crash mid-write never leaves
        a truncated index behind. The temp name does not end in ``.index``, so
        ``_load_conversation_indexes`` never picks up a leftover one.
        """
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{path}.{uuid.uuid4().hex}.tmp"
        try:
            faiss.write_index(index, tmp_path)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; the original error propagates.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _load_index_from_path(self, path: str, kind: str = "global"):
        """Read the index at *path*, converting it to ``IndexIDMap2`` (of *kind*) if needed."""
        index = faiss.read_index(path)
        return self._ensure_id_map(index, kind)

    def _ensure_id_map(self, index, kind: str = "global"):
        """Return *index* unchanged if it already carries an id map.

        Otherwise rebuild it as an ``IndexIDMap2`` over a fresh base index of
        *kind*, copying every stored vector with ids ``0..ntotal-1`` (the
        positional ids the un-mapped index implicitly used).
        """
        if hasattr(index, "id_map"):
            return index
        converted = faiss.IndexIDMap2(self._create_base_index(kind))
        if index.ntotal == 0:
            return converted
        vectors = np.empty((index.ntotal, index.d), dtype=np.float32)
        index.reconstruct_n(0, index.ntotal, vectors)
        ids = np.arange(index.ntotal, dtype=np.int64)
        converted.add_with_ids(vectors, ids)
        return converted

    # ---------------------------------------------------------------------- ids

    def _index_ids(self, index) -> np.ndarray:
        """Return the ids stored in *index* (positional ids if it has no id map)."""
        if not hasattr(index, "id_map"):
            return np.arange(index.ntotal, dtype=np.int64)
        return faiss.vector_to_array(index.id_map)

    def _next_global_id(self) -> int:
        """Return one past the largest id in the global index (0 when empty).

        NOTE(review): copies the full id array on every call, and is not
        synchronised with the subsequent add — confirm single-threaded use.
        """
        ids = self._index_ids(self.global_index)
        if ids.size == 0:
            return 0
        return int(ids.max()) + 1

    def _get_conversation_index(self, idx: str):
        """Return the conversation index *idx*, loading it from disk on first use.

        Raises:
            ValueError: if *idx* is invalid or no index file exists for it.
        """
        if idx in self.conversation_indexes:
            return self.conversation_indexes[idx]
        index_path = self._conversation_index_path(idx)
        if not os.path.isfile(index_path):
            raise ValueError(f"索引不存在: {idx}")
        index = self._load_index_from_path(index_path, "conversation")
        self.conversation_indexes[idx] = index
        return index

    # --------------------------------------------------------------- public API

    def create(self) -> str:
        """Create and persist an empty conversation index; return its new id."""
        idx = f"conv_{uuid.uuid4().hex}"
        index = self._create_index("conversation")
        self.conversation_indexes[idx] = index
        self._persist_index(index, self._conversation_index_path(idx))
        return idx

    def insert(self, embedding: list[float]) -> int:
        """Add *embedding* to the global index under the next free id; return that id."""
        vector_id = self._next_global_id()
        self.insert_global_with_id(vector_id, embedding)
        return vector_id

    def insert_global_with_id(self, vector_id: int, embedding: list[float]) -> int:
        """Add *embedding* to the global index under *vector_id*, persist, and return the id."""
        vec = self._prepare_embedding(embedding)
        ids = np.array([vector_id], dtype=np.int64)
        self.global_index.add_with_ids(vec, ids)
        self._persist_index(self.global_index, self.global_index_path)
        return vector_id

    def insert_idx(self, idx: str, vector_id: int, embedding: list[float]) -> int:
        """Add *embedding* to conversation index *idx* under *vector_id*, persist, return the id."""
        index = self._get_conversation_index(idx)
        vec = self._prepare_embedding(embedding)
        ids = np.array([vector_id], dtype=np.int64)
        index.add_with_ids(vec, ids)
        self._persist_index(index, self._conversation_index_path(idx))
        return vector_id

    def _search(self, index, embedding: list[float], k: int = 5):
        """Return the *k* nearest neighbours of *embedding* in *index*.

        Returns a dict of parallel lists: ``ids`` (``-1`` marks an empty slot
        when the index holds fewer than *k* vectors), ``distances`` (raw FAISS
        scores, 0.0 for empty slots) and ``similarity_scores``. A non-positive
        *k* yields empty lists.

        NOTE(review): for L2 indexes FAISS reports *squared* distances, so
        ``1 - distance`` is not bounded to [0, 1]; and with cosine enabled but
        a global L2 index (FAISS_INDEX_TYPE neither HNSW nor FlatIP) the raw
        L2 distance is reported as the similarity — confirm this is intended.
        """
        vec = self._prepare_embedding(embedding)
        if k <= 0:
            # FAISS asserts k > 0; an empty result is the natural answer.
            return {"ids": [], "distances": [], "similarity_scores": []}
        distances, indices = index.search(vec, k)
        ids = indices[0].tolist()
        raw_distances = distances[0].tolist()
        normalized_distances = []
        similarity_scores = []
        for vector_id, distance in zip(ids, raw_distances):
            if vector_id < 0:
                # Padding slot: FAISS fills it with a sentinel distance, report 0.0 instead.
                normalized_distances.append(0.0)
                similarity_scores.append(0.0)
                continue
            normalized_distances.append(distance)
            similarity_scores.append(1 - distance if not self.use_cosine else distance)
        return {
            "ids": ids,
            "distances": normalized_distances,
            "similarity_scores": similarity_scores,
        }

    def search(self, embedding: list[float], k: int = 5):
        """Search the global index; see ``_search`` for the result format."""
        return self._search(self.global_index, embedding, k)

    def search_idx(self, idx: str, embedding: list[float], k: int = 5):
        """Search conversation index *idx*; see ``_search`` for the result format."""
        index = self._get_conversation_index(idx)
        return self._search(index, embedding, k)

    def delete(self, idx: str) -> bool:
        """Drop conversation index *idx* from memory and delete its file.

        Raises:
            ValueError: if *idx* is invalid, or is neither loaded nor on disk.
        """
        index_path = self._conversation_index_path(idx)
        # Existence check only — no need to deserialise the index just to delete it.
        if idx not in self.conversation_indexes and not os.path.isfile(index_path):
            raise ValueError(f"索引不存在: {idx}")
        self.conversation_indexes.pop(idx, None)
        if os.path.exists(index_path):
            os.remove(index_path)
        return True

    def persist(self):
        """Write the global index and every loaded conversation index to disk."""
        self._persist_index(self.global_index, self.global_index_path)
        for idx, index in self.conversation_indexes.items():
            self._persist_index(index, self._conversation_index_path(idx))
        return True
# Module-level shared instance. Constructing it runs FaissManager._load_indexes,
# so importing this module creates the index directories and loads (or creates
# and writes) the on-disk indexes as a side effect.
faiss_manager = FaissManager()