faiss server
This commit is contained in:
87
faiss/faiss_manager.py
Normal file
87
faiss/faiss_manager.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# faiss_manager.py
|
||||
import os
|
||||
import numpy as np
|
||||
import faiss
|
||||
from config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
class FaissManager:
|
||||
def __init__(self):
|
||||
self.dim = settings.FAISS_DIM
|
||||
self.index_path = settings.FAISS_INDEX_PATH
|
||||
self.use_cosine = settings.USE_COSINE_SIMILARITY
|
||||
self.index = None
|
||||
self._load_or_create_index()
|
||||
|
||||
def _load_or_create_index(self):
|
||||
if os.path.exists(self.index_path):
|
||||
self.index = faiss.read_index(self.index_path)
|
||||
print(f"✅ 加载已有索引:{self.index.ntotal} 个向量,维度={self.index.d}")
|
||||
return
|
||||
|
||||
# 创建新索引
|
||||
if settings.FAISS_INDEX_TYPE == "HNSW":
|
||||
if self.use_cosine:
|
||||
self.index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M, faiss.METRIC_INNER_PRODUCT)
|
||||
print("✅ 创建 HNSWIP 索引(余弦相似度)")
|
||||
else:
|
||||
self.index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M)
|
||||
print("✅ 创建 HNSWFlat 索引(L2 距离)")
|
||||
|
||||
# 设置 HNSW 参数
|
||||
self.index.hnsw.efConstruction = settings.HNSW_EF_CONSTRUCTION
|
||||
self.index.hnsw.efSearch = settings.HNSW_EF_SEARCH
|
||||
print(f" HNSW 参数: M={settings.HNSW_M}, efConstruction={settings.HNSW_EF_CONSTRUCTION}, efSearch={settings.HNSW_EF_SEARCH}")
|
||||
|
||||
elif settings.FAISS_INDEX_TYPE == "FlatIP" and self.use_cosine:
|
||||
self.index = faiss.IndexFlatIP(self.dim)
|
||||
print("✅ 创建 FlatIP 索引(精确余弦)")
|
||||
else:
|
||||
# 默认精确 L2(兼容旧配置)
|
||||
self.index = faiss.IndexFlatL2(self.dim)
|
||||
print("✅ 创建 FlatL2 索引(精确欧式)")
|
||||
|
||||
def _normalize(self, embedding: list[float]) -> np.ndarray:
|
||||
"""L2 归一化(余弦相似度必需)"""
|
||||
vec = np.array(embedding, dtype=np.float32)
|
||||
norm = np.linalg.norm(vec)
|
||||
return vec / norm if norm > 0 else vec
|
||||
|
||||
def insert(self, embedding: list[float]) -> int:
|
||||
"""插入向量,返回 ID"""
|
||||
if len(embedding) != self.dim:
|
||||
raise ValueError(f"Embedding 维度错误,应为 {self.dim}")
|
||||
|
||||
vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32)
|
||||
vec = vec.reshape(1, -1)
|
||||
|
||||
idx = self.index.ntotal
|
||||
self.index.add(vec)
|
||||
return idx
|
||||
|
||||
def search(self, embedding: list[float], k: int = 5):
|
||||
"""搜索相似向量(返回 id + 距离)"""
|
||||
if len(embedding) != self.dim:
|
||||
raise ValueError(f"Embedding 维度错误,应为 {self.dim}")
|
||||
|
||||
vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32)
|
||||
vec = vec.reshape(1, -1)
|
||||
|
||||
distances, indices = self.index.search(vec, k)
|
||||
|
||||
return {
|
||||
"ids": indices[0].tolist(),
|
||||
"distances": distances[0].tolist(), # 余弦时:值越大越相似(1.0=完全相同)
|
||||
"similarity_scores": [1 - d for d in distances[0].tolist()] if not self.use_cosine else distances[0].tolist()
|
||||
}
|
||||
|
||||
def persist(self):
|
||||
"""保存索引"""
|
||||
faiss.write_index(self.index, self.index_path)
|
||||
print(f"💾 索引已保存 → {self.index_path}(共 {self.index.ntotal} 个向量)")
|
||||
return True
|
||||
|
||||
|
||||
# 单例
|
||||
faiss_manager = FaissManager()
|
||||
Reference in New Issue
Block a user