Files
mchat/faiss/faiss_manager.py
2026-04-10 11:55:00 +00:00

88 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# faiss_manager.py
import os
import numpy as np
import faiss
from config import get_settings
# Settings are loaded once at import time; FaissManager reads its FAISS_*,
# HNSW_* and USE_COSINE_SIMILARITY values from this object.
settings = get_settings()
class FaissManager:
    """Own a single FAISS index: create or load it, insert/search vectors, persist it.

    Configuration comes from the module-level ``settings`` object:
    FAISS_DIM, FAISS_INDEX_PATH, USE_COSINE_SIMILARITY, FAISS_INDEX_TYPE,
    HNSW_M, HNSW_EF_CONSTRUCTION and HNSW_EF_SEARCH.

    Vector IDs are insertion positions (the value of ``ntotal`` before each
    ``insert``), so they are 0, 1, 2, ... in insertion order.
    """

    def __init__(self):
        self.dim = settings.FAISS_DIM
        self.index_path = settings.FAISS_INDEX_PATH
        # When True, vectors are L2-normalized before insert/search so that an
        # inner-product index ranks by cosine similarity.
        self.use_cosine = settings.USE_COSINE_SIMILARITY
        self.index = None
        self._load_or_create_index()

    def _load_or_create_index(self):
        """Load the index from ``index_path`` if it exists, else build an empty one.

        New-index selection:
          * ``FAISS_INDEX_TYPE == "HNSW"`` -> IndexHNSWFlat (inner product when
            cosine is enabled, L2 otherwise), with M/efConstruction/efSearch
            taken from settings;
          * ``"FlatIP"`` with cosine enabled -> IndexFlatIP;
          * anything else (including cosine with an unrecognised type) ->
            IndexFlatL2.
        """
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
            print(f"✅ 加载已有索引:{self.index.ntotal} 个向量,维度={self.index.d}")
            if self.index.d != self.dim:
                # insert/search validate against the configured dim; without this
                # warning a mismatch only shows up later as a faiss assertion.
                print(f"⚠️ 索引维度={self.index.d} 与配置 FAISS_DIM={self.dim} 不一致")
            if hasattr(self.index, "hnsw"):
                # efSearch is a query-time knob: honour the current config rather
                # than the value that was saved together with the index file.
                self.index.hnsw.efSearch = settings.HNSW_EF_SEARCH
            return

        # Build a new, empty index.
        if settings.FAISS_INDEX_TYPE == "HNSW":
            if self.use_cosine:
                self.index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M, faiss.METRIC_INNER_PRODUCT)
                print("✅ 创建 HNSWIP 索引(余弦相似度)")
            else:
                self.index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M)
                print("✅ 创建 HNSWFlat 索引L2 距离)")
            # HNSW build/search parameters.
            self.index.hnsw.efConstruction = settings.HNSW_EF_CONSTRUCTION
            self.index.hnsw.efSearch = settings.HNSW_EF_SEARCH
            print(f" HNSW 参数: M={settings.HNSW_M}, efConstruction={settings.HNSW_EF_CONSTRUCTION}, efSearch={settings.HNSW_EF_SEARCH}")
        elif settings.FAISS_INDEX_TYPE == "FlatIP" and self.use_cosine:
            self.index = faiss.IndexFlatIP(self.dim)
            print("✅ 创建 FlatIP 索引(精确余弦)")
        else:
            # Default: exact L2 (kept for compatibility with older configs).
            self.index = faiss.IndexFlatL2(self.dim)
            print("✅ 创建 FlatL2 索引(精确欧式)")

    def _normalize(self, embedding: list[float]) -> np.ndarray:
        """Return *embedding* as a float32 vector scaled to unit L2 norm.

        A zero vector is returned unchanged (there is no direction to keep).
        Needed for cosine similarity on an inner-product index.
        """
        vec = np.array(embedding, dtype=np.float32)
        norm = np.linalg.norm(vec)
        return vec / norm if norm > 0 else vec

    def _prepare(self, embedding: list[float]) -> np.ndarray:
        """Validate *embedding* and convert it to the (1, dim) float32 batch FAISS expects.

        Raises:
            ValueError: if ``len(embedding) != self.dim``.
        """
        if len(embedding) != self.dim:
            raise ValueError(f"Embedding 维度错误,应为 {self.dim}")
        vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32)
        return vec.reshape(1, -1)

    def insert(self, embedding: list[float]) -> int:
        """Add one vector to the index and return its ID.

        The ID is ``ntotal`` before the add, i.e. the vector's insertion position.

        Raises:
            ValueError: if ``len(embedding) != self.dim``.
        """
        vec = self._prepare(embedding)
        idx = self.index.ntotal
        self.index.add(vec)
        return idx

    def search(self, embedding: list[float], k: int = 5):
        """Find up to *k* nearest stored vectors.

        Returns a dict of parallel lists, best match first:
          * ``ids``: IDs as returned by ``insert``;
          * ``distances``: raw FAISS scores (inner product for IP indexes,
            larger = more similar, 1.0 = identical for normalized vectors;
            squared L2 for L2 indexes, smaller = more similar);
          * ``similarity_scores``: the raw score for IP indexes, ``1 - d`` for
            L2 indexes.

        Fewer than *k* results are returned when the index holds fewer vectors;
        FAISS's ``-1`` padding entries are never returned.

        Raises:
            ValueError: if ``len(embedding) != self.dim``.
        """
        vec = self._prepare(embedding)
        ntotal = self.index.ntotal
        if k <= 0 or ntotal == 0:
            return {"ids": [], "distances": [], "similarity_scores": []}

        # Asking for more neighbours than exist only yields -1 padding.
        distances, indices = self.index.search(vec, min(k, ntotal))

        ids: list[int] = []
        dists: list[float] = []
        for vid, dist in zip(indices[0].tolist(), distances[0].tolist()):
            if vid == -1:
                # FAISS marks "no result" with -1 (e.g. HNSW finding fewer than
                # k); -1 would silently alias the last element of a Python list.
                continue
            ids.append(vid)
            dists.append(dist)

        # Decide by the index's real metric, not the config flag: a cosine config
        # can still end up with an L2 index (fallback branch or a loaded file).
        if self.index.metric_type == faiss.METRIC_INNER_PRODUCT:
            scores = list(dists)
        else:
            scores = [1 - d for d in dists]

        return {
            "ids": ids,
            "distances": dists,
            "similarity_scores": scores,
        }

    def persist(self):
        """Write the index to ``index_path`` and return True.

        The index is written to a sibling temp file first and then moved into
        place, so an interrupted write never clobbers the previous file.
        Missing parent directories are created.
        """
        directory = os.path.dirname(self.index_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{self.index_path}.tmp"
        faiss.write_index(self.index, tmp_path)
        os.replace(tmp_path, self.index_path)
        print(f"💾 索引已保存 → {self.index_path}(共 {self.index.ntotal} 个向量)")
        return True
# Module-level singleton: constructed at import time, which loads the index
# from FAISS_INDEX_PATH if that file exists, otherwise creates an empty one.
faiss_manager = FaissManager()