88 lines
3.4 KiB
Python
88 lines
3.4 KiB
Python
# faiss_manager.py
|
||
import os
|
||
import numpy as np
|
||
import faiss
|
||
from config import get_settings
|
||
|
||
# Project settings object; FaissManager reads its FAISS_*, HNSW_* and
# USE_COSINE_SIMILARITY attributes.
settings = get_settings()
|
||
|
||
class FaissManager:
    """Wrapper around a single FAISS index configured from ``settings``.

    On construction the index is read from ``settings.FAISS_INDEX_PATH`` when
    that file exists; otherwise a new index is created according to
    ``settings.FAISS_INDEX_TYPE`` and ``settings.USE_COSINE_SIMILARITY``.
    When cosine similarity is enabled, every vector is L2-normalized before
    it is added or searched, so the inner product equals cosine similarity.
    """

    def __init__(self):
        """Read the configuration and load or create the index."""
        self.dim = settings.FAISS_DIM
        self.index_path = settings.FAISS_INDEX_PATH
        self.use_cosine = settings.USE_COSINE_SIMILARITY
        self.index = None
        self._load_or_create_index()

    def _load_or_create_index(self):
        """Populate ``self.index`` from disk, or build a fresh empty index."""
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
            print(f"✅ 加载已有索引:{self.index.ntotal} 个向量,维度={self.index.d}")
            if self.index.d != self.dim:
                # The persisted index is what add/search actually run against,
                # so validate embeddings against its dimension instead of
                # letting a mismatch surface later inside FAISS.
                print(
                    f"⚠️ 索引维度({self.index.d})与配置 FAISS_DIM({self.dim})不一致,"
                    "改用索引维度"
                )
                self.dim = self.index.d
            return

        # No index file on disk: create a new index.
        if settings.FAISS_INDEX_TYPE == "HNSW":
            if self.use_cosine:
                self.index = faiss.IndexHNSWFlat(
                    self.dim, settings.HNSW_M, faiss.METRIC_INNER_PRODUCT
                )
                print("✅ 创建 HNSWIP 索引(余弦相似度)")
            else:
                self.index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M)
                print("✅ 创建 HNSWFlat 索引(L2 距离)")

            # Apply the configured HNSW build/search parameters.
            self.index.hnsw.efConstruction = settings.HNSW_EF_CONSTRUCTION
            self.index.hnsw.efSearch = settings.HNSW_EF_SEARCH
            print(f" HNSW 参数: M={settings.HNSW_M}, efConstruction={settings.HNSW_EF_CONSTRUCTION}, efSearch={settings.HNSW_EF_SEARCH}")

        elif settings.FAISS_INDEX_TYPE == "FlatIP" and self.use_cosine:
            self.index = faiss.IndexFlatIP(self.dim)
            print("✅ 创建 FlatIP 索引(精确余弦)")
        else:
            # Default: exact L2 (keeps older configurations working).
            self.index = faiss.IndexFlatL2(self.dim)
            print("✅ 创建 FlatL2 索引(精确欧式)")

    def _normalize(self, embedding: list[float]) -> np.ndarray:
        """Return *embedding* as a float32 vector scaled to unit L2 norm.

        A zero vector is returned unchanged to avoid dividing by zero.
        """
        vec = np.array(embedding, dtype=np.float32)
        norm = np.linalg.norm(vec)
        return vec / norm if norm > 0 else vec

    def _as_row(self, embedding: list[float]) -> np.ndarray:
        """Validate *embedding* and return it as a ``(1, dim)`` float32 array.

        The vector is L2-normalized when cosine similarity is enabled.

        Raises:
            ValueError: if ``len(embedding)`` differs from ``self.dim``.
        """
        if len(embedding) != self.dim:
            raise ValueError(f"Embedding 维度错误,应为 {self.dim}")

        if self.use_cosine:
            vec = self._normalize(embedding)
        else:
            vec = np.array(embedding, dtype=np.float32)
        return vec.reshape(1, -1)

    def insert(self, embedding: list[float]) -> int:
        """Add one vector to the index and return its ID.

        The ID is the index's ``ntotal`` before the add, i.e. the position
        the new vector takes in the index.

        Raises:
            ValueError: if the embedding has the wrong dimension.
        """
        vec = self._as_row(embedding)

        idx = self.index.ntotal
        self.index.add(vec)
        return idx

    def search(self, embedding: list[float], k: int = 5) -> dict[str, list]:
        """Search for the vectors most similar to *embedding*.

        Returns a dict with three parallel lists:

        * ``ids`` - IDs of the matches.
        * ``distances`` - raw values reported by the index; with cosine
          similarity enabled, larger means more similar (1.0 = identical).
        * ``similarity_scores`` - the raw value when cosine similarity is
          enabled, otherwise ``1 - distance``.

        Fewer than *k* entries are returned when the index cannot supply
        *k* neighbours.

        Raises:
            ValueError: if the embedding has the wrong dimension.
        """
        vec = self._as_row(embedding)

        distances, indices = self.index.search(vec, k)

        # FAISS pads the result with id -1 (and a sentinel distance) when it
        # finds fewer than k neighbours; those slots are not real matches.
        hits = [
            (i, d)
            for i, d in zip(indices[0].tolist(), distances[0].tolist())
            if i >= 0
        ]
        ids = [i for i, _ in hits]
        dists = [d for _, d in hits]
        scores = list(dists) if self.use_cosine else [1 - d for d in dists]

        return {
            "ids": ids,
            "distances": dists,
            "similarity_scores": scores,
        }

    def persist(self):
        """Write the index to ``self.index_path`` and return ``True``.

        The parent directory is created first if it does not exist.
        """
        parent = os.path.dirname(self.index_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        faiss.write_index(self.index, self.index_path)
        print(f"💾 索引已保存 → {self.index_path}(共 {self.index.ntotal} 个向量)")
        return True
|
||
|
||
|
||
# Singleton
|
||
# Instantiated at import time: the constructor loads the index file if it
# exists, otherwise it creates a new empty index.
faiss_manager = FaissManager()
|