提供对话上下文索引
This commit is contained in:
7
faiss/.gitignore
vendored
7
faiss/.gitignore
vendored
@@ -1,3 +1,6 @@
|
||||
faiss_index.bin
|
||||
indexes/*
|
||||
!indexes/.gitkeep
|
||||
!indexes/global/.gitkeep
|
||||
!indexes/conversations/.gitkeep
|
||||
.vscode
|
||||
__pycache__
|
||||
__pycache__
|
||||
|
||||
68
faiss/api.py
68
faiss/api.py
@@ -1,41 +1,81 @@
|
||||
# api.py
|
||||
from fastapi import FastAPI, Depends, HTTPException
|
||||
from models import EmbeddingInput, SearchInput
|
||||
from faiss_manager import faiss_manager
|
||||
from fastapi import FastAPI, HTTPException
|
||||
|
||||
from config import get_settings
|
||||
from faiss_manager import faiss_manager
|
||||
from models import EmbeddingInput, IndexDeleteInput, IndexInsertInput, IndexSearchInput, SearchInput
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
app = FastAPI(
|
||||
title="FAISS 服务",
|
||||
description="向量插入 + 相似搜索 + 持久化",
|
||||
version="1.0.0"
|
||||
description="向量插入 + 相似搜索 + 多索引管理",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
|
||||
def log_business(message: str):
    """Print a business-level log line when request logging is switched on.

    Args:
        message: Text to write to standard output.
    """
    if not settings.ENABLE_REQUEST_LOGS:
        return
    # Flush on every call so the line is emitted immediately.
    print(message, flush=True)
|
||||
|
||||
|
||||
@app.post("/insert")
|
||||
async def insert(data: EmbeddingInput):
|
||||
try:
|
||||
vector_id = faiss_manager.insert(data.embedding)
|
||||
log_business(f"[faiss] insert id={vector_id}")
|
||||
log_business(f"[faiss] global insert id={vector_id}")
|
||||
return {"id": vector_id}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@app.post("/search")
|
||||
async def search(data: SearchInput):
|
||||
try:
|
||||
result = faiss_manager.search(data.embedding, data.k)
|
||||
log_business(
|
||||
f"[faiss] search ids={result['ids']} similarity_scores={result['similarity_scores']}",
|
||||
)
|
||||
log_business(f"[faiss] global search ids={result['ids']} similarity_scores={result['similarity_scores']}")
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@app.post("/create")
async def create():
    """Create a new index through the manager and return its identifier.

    Returns:
        ``{"idx": <identifier>}`` for the newly created index.

    Raises:
        HTTPException: Status 400 carrying the error text if creation fails.
    """
    try:
        new_idx = faiss_manager.create()
        log_business(f"[faiss] create idx={new_idx}")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"idx": new_idx}
|
||||
|
||||
|
||||
@app.post("/insert_idx")
async def insert_idx(data: IndexInsertInput):
    """Insert one embedding into the index named in the request.

    Args:
        data: Request body with the index name, the vector id, and the
            embedding.

    Returns:
        ``{"id": <stored id>, "idx": <index name>}``.

    Raises:
        HTTPException: Status 400 carrying the error text on any failure.
    """
    try:
        stored_id = faiss_manager.insert_idx(data.idx, data.id, data.embedding)
        log_business(f"[faiss] insert_idx idx={data.idx} id={stored_id}")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"id": stored_id, "idx": data.idx}
|
||||
|
||||
|
||||
@app.post("/search_idx")
async def search_idx(data: IndexSearchInput):
    """Search the index named in the request for vectors similar to the query.

    Args:
        data: Request body with the index name, the query embedding, and
            the number of results ``k``.

    Returns:
        The result mapping produced by ``faiss_manager.search_idx``.

    Raises:
        HTTPException: Status 400 carrying the error text on any failure.
    """
    try:
        hits = faiss_manager.search_idx(data.idx, data.embedding, data.k)
        # The log line is built inside the try block so that a malformed
        # result is reported to the client in the same way as a failed search.
        found_ids = hits["ids"]
        scores = hits["similarity_scores"]
        log_business(f"[faiss] search_idx idx={data.idx} ids={found_ids} similarity_scores={scores}")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return hits
|
||||
|
||||
|
||||
@app.post("/del")
async def delete_idx(data: IndexDeleteInput):
    """Delete the index named in the request.

    Args:
        data: Request body with the name of the index to delete.

    Returns:
        ``{"status": "success", "idx": <index name>}``.

    Raises:
        HTTPException: Status 400 carrying the error text on any failure.
    """
    try:
        faiss_manager.delete(data.idx)
        log_business(f"[faiss] delete idx={data.idx}")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"status": "success", "idx": data.idx}
|
||||
|
||||
|
||||
@app.post("/persist")
|
||||
|
||||
@@ -5,6 +5,10 @@ from functools import lru_cache
|
||||
class Settings(BaseSettings):
|
||||
# FAISS 配置(已优化为你的 LLM 相似问题缓存场景)
|
||||
FAISS_DIM: int = 1024 # 根据你的 embedding 模型修改(e.g. bge-large=1024, text-embedding-3-large=3072)
|
||||
FAISS_INDEX_DIR: str = "indexes"
|
||||
FAISS_GLOBAL_DIR: str = "global"
|
||||
FAISS_CONVERSATION_DIR: str = "conversations"
|
||||
FAISS_GLOBAL_INDEX_NAME: str = "global_qa.index"
|
||||
FAISS_INDEX_PATH: str = "faiss_index.bin"
|
||||
FAISS_INDEX_TYPE: str = "HNSW" # 默认改为 HNSW(最推荐)
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
services:
|
||||
faiss:
|
||||
build: .
|
||||
container_name: faiss-service
|
||||
ports:
|
||||
- "8451:8000"
|
||||
volumes:
|
||||
- ./faiss_index.bin:/app/faiss_index.bin # 持久化索引文件
|
||||
- ./.env:/app/.env # 可选:挂载配置
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- FAISS_DIM=1024
|
||||
- APP_PORT=8000
|
||||
@@ -1,87 +1,199 @@
|
||||
# faiss_manager.py
|
||||
import os
|
||||
import numpy as np
|
||||
import uuid
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
from config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class FaissManager:
|
||||
def __init__(self):
|
||||
self.dim = settings.FAISS_DIM
|
||||
self.index_path = settings.FAISS_INDEX_PATH
|
||||
self.index_dir = settings.FAISS_INDEX_DIR
|
||||
self.global_dir = os.path.join(self.index_dir, settings.FAISS_GLOBAL_DIR)
|
||||
self.conversation_dir = os.path.join(self.index_dir, settings.FAISS_CONVERSATION_DIR)
|
||||
self.global_index_path = os.path.join(self.global_dir, settings.FAISS_GLOBAL_INDEX_NAME)
|
||||
self.legacy_index_path = settings.FAISS_INDEX_PATH
|
||||
self.use_cosine = settings.USE_COSINE_SIMILARITY
|
||||
self.index = None
|
||||
self._load_or_create_index()
|
||||
self.global_index = None
|
||||
self.conversation_indexes = {}
|
||||
self._load_indexes()
|
||||
|
||||
def _load_or_create_index(self):
|
||||
if os.path.exists(self.index_path):
|
||||
self.index = faiss.read_index(self.index_path)
|
||||
print(f"✅ 加载已有索引:{self.index.ntotal} 个向量,维度={self.index.d}")
|
||||
return
|
||||
def _load_indexes(self):
    """Prepare the index directories and populate the in-memory indexes.

    The global index comes from the first source that applies: the saved
    global index file; otherwise the legacy index file, which is converted
    to an id-mapped index and saved at the global path; otherwise a newly
    created index, which is also saved. Afterwards every ``*.index`` file in
    the conversation directory is loaded into ``self.conversation_indexes``,
    keyed by its file name without the suffix.
    """
    for directory in (self.global_dir, self.conversation_dir):
        os.makedirs(directory, exist_ok=True)

    global_path = self.global_index_path
    if os.path.isfile(global_path):
        self.global_index = self._load_index_from_path(global_path)
        print(f"✅ 加载全局索引:{self.global_index.ntotal} 个向量")
    elif os.path.isfile(self.legacy_index_path):
        # Only the legacy file exists: convert it and save it at the new path.
        self.global_index = self._ensure_id_map(faiss.read_index(self.legacy_index_path))
        self._persist_index(self.global_index, global_path)
        print(f"✅ 从旧索引迁移全局索引:{self.global_index.ntotal} 个向量")
    else:
        self.global_index = self._create_index("global")
        self._persist_index(self.global_index, global_path)
        print("✅ 创建全局索引")

    suffix = ".index"
    # Sorted so the load order does not depend on directory listing order.
    for file_name in sorted(os.listdir(self.conversation_dir)):
        if file_name.endswith(suffix):
            idx = file_name[: -len(suffix)]
            self.conversation_indexes[idx] = self._load_index_from_path(
                self._conversation_index_path(idx)
            )

    if self.conversation_indexes:
        print(f"✅ 加载对话索引:{len(self.conversation_indexes)} 个")
|
||||
|
||||
def _create_base_index(self, kind: str):
|
||||
if kind == "conversation":
|
||||
if self.use_cosine:
|
||||
return faiss.IndexFlatIP(self.dim)
|
||||
return faiss.IndexFlatL2(self.dim)
|
||||
|
||||
# 创建新索引
|
||||
if settings.FAISS_INDEX_TYPE == "HNSW":
|
||||
if self.use_cosine:
|
||||
self.index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M, faiss.METRIC_INNER_PRODUCT)
|
||||
print("✅ 创建 HNSWIP 索引(余弦相似度)")
|
||||
index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M, faiss.METRIC_INNER_PRODUCT)
|
||||
else:
|
||||
self.index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M)
|
||||
print("✅ 创建 HNSWFlat 索引(L2 距离)")
|
||||
|
||||
# 设置 HNSW 参数
|
||||
self.index.hnsw.efConstruction = settings.HNSW_EF_CONSTRUCTION
|
||||
self.index.hnsw.efSearch = settings.HNSW_EF_SEARCH
|
||||
print(f" HNSW 参数: M={settings.HNSW_M}, efConstruction={settings.HNSW_EF_CONSTRUCTION}, efSearch={settings.HNSW_EF_SEARCH}")
|
||||
index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M)
|
||||
index.hnsw.efConstruction = settings.HNSW_EF_CONSTRUCTION
|
||||
index.hnsw.efSearch = settings.HNSW_EF_SEARCH
|
||||
return index
|
||||
if settings.FAISS_INDEX_TYPE == "FlatIP" and self.use_cosine:
|
||||
return faiss.IndexFlatIP(self.dim)
|
||||
return faiss.IndexFlatL2(self.dim)
|
||||
|
||||
elif settings.FAISS_INDEX_TYPE == "FlatIP" and self.use_cosine:
|
||||
self.index = faiss.IndexFlatIP(self.dim)
|
||||
print("✅ 创建 FlatIP 索引(精确余弦)")
|
||||
else:
|
||||
# 默认精确 L2(兼容旧配置)
|
||||
self.index = faiss.IndexFlatL2(self.dim)
|
||||
print("✅ 创建 FlatL2 索引(精确欧式)")
|
||||
def _create_index(self, kind: str):
    """Build a fresh index of the given kind that stores explicit vector ids.

    Args:
        kind: Passed unchanged to ``_create_base_index``.

    Returns:
        A ``faiss.IndexIDMap2`` wrapping the newly created base index.
    """
    base_index = self._create_base_index(kind)
    return faiss.IndexIDMap2(base_index)
|
||||
|
||||
def _normalize(self, embedding: list[float]) -> np.ndarray:
|
||||
"""L2 归一化(余弦相似度必需)"""
|
||||
vec = np.array(embedding, dtype=np.float32)
|
||||
norm = np.linalg.norm(vec)
|
||||
return vec / norm if norm > 0 else vec
|
||||
|
||||
def insert(self, embedding: list[float]) -> int:
|
||||
"""插入向量,返回 ID"""
|
||||
def _prepare_embedding(self, embedding: list[float]) -> np.ndarray:
|
||||
if len(embedding) != self.dim:
|
||||
raise ValueError(f"Embedding 维度错误,应为 {self.dim}")
|
||||
|
||||
vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32)
|
||||
vec = vec.reshape(1, -1)
|
||||
|
||||
idx = self.index.ntotal
|
||||
self.index.add(vec)
|
||||
return vec.reshape(1, -1)
|
||||
|
||||
def _conversation_index_path(self, idx: str) -> str:
|
||||
return os.path.join(self.conversation_dir, f"{idx}.index")
|
||||
|
||||
def _persist_index(self, index, path: str):
    """Write ``index`` to ``path``, creating the parent directory if needed.

    The index is first written to a temporary sibling file and then moved
    over ``path`` with ``os.replace``, so a failure part-way through the
    write leaves any previously saved file at ``path`` untouched.

    Args:
        index: The FAISS index to save.
        path: Destination file path.
    """
    parent = os.path.dirname(path)
    if parent:
        # os.makedirs("") raises, so skip it when path is a bare file name.
        os.makedirs(parent, exist_ok=True)

    tmp_path = f"{path}.tmp"
    try:
        faiss.write_index(index, tmp_path)
        os.replace(tmp_path, path)
    except Exception:
        # Best-effort removal of the partial file; the original error is
        # what the caller needs to see, so a cleanup failure is ignored.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
||||
|
||||
def _load_index_from_path(self, path: str):
    """Read the index stored at ``path`` and return it in id-mapped form.

    Args:
        path: File to read with ``faiss.read_index``.

    Returns:
        The loaded index after passing through ``_ensure_id_map``.
    """
    return self._ensure_id_map(faiss.read_index(path))
|
||||
|
||||
def _ensure_id_map(self, index):
|
||||
if hasattr(index, "id_map"):
|
||||
return index
|
||||
|
||||
converted = faiss.IndexIDMap2(self._create_base_index("global"))
|
||||
if index.ntotal == 0:
|
||||
return converted
|
||||
|
||||
vectors = np.empty((index.ntotal, index.d), dtype=np.float32)
|
||||
index.reconstruct_n(0, index.ntotal, vectors)
|
||||
ids = np.arange(index.ntotal, dtype=np.int64)
|
||||
converted.add_with_ids(vectors, ids)
|
||||
return converted
|
||||
|
||||
def _index_ids(self, index) -> np.ndarray:
|
||||
if not hasattr(index, "id_map"):
|
||||
return np.arange(index.ntotal, dtype=np.int64)
|
||||
return faiss.vector_to_array(index.id_map)
|
||||
|
||||
def _next_global_id(self) -> int:
|
||||
ids = self._index_ids(self.global_index)
|
||||
if ids.size == 0:
|
||||
return 0
|
||||
return int(ids.max()) + 1
|
||||
|
||||
def _get_conversation_index(self, idx: str):
|
||||
if idx in self.conversation_indexes:
|
||||
return self.conversation_indexes[idx]
|
||||
|
||||
index_path = self._conversation_index_path(idx)
|
||||
if not os.path.isfile(index_path):
|
||||
raise ValueError(f"索引不存在: {idx}")
|
||||
index = self._load_index_from_path(index_path)
|
||||
self.conversation_indexes[idx] = index
|
||||
return index
|
||||
|
||||
def create(self) -> str:
    """Create an empty conversation index, register it, and save it to disk.

    Returns:
        The generated index name, ``"conv_"`` followed by a random UUID in
        hexadecimal form.
    """
    name = "conv_" + uuid.uuid4().hex
    fresh = self._create_index("conversation")
    self.conversation_indexes[name] = fresh
    target = self._conversation_index_path(name)
    self._persist_index(fresh, target)
    return name
|
||||
|
||||
def search(self, embedding: list[float], k: int = 5):
|
||||
"""搜索相似向量(返回 id + 距离)"""
|
||||
if len(embedding) != self.dim:
|
||||
raise ValueError(f"Embedding 维度错误,应为 {self.dim}")
|
||||
def insert(self, embedding: list[float]) -> int:
    """Add ``embedding`` to the global index under the next free id.

    Args:
        embedding: The vector to store.

    Returns:
        The id assigned to the vector.
    """
    # insert_global_with_id hands back the id it was given.
    return self.insert_global_with_id(self._next_global_id(), embedding)
|
||||
|
||||
vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32)
|
||||
vec = vec.reshape(1, -1)
|
||||
def insert_global_with_id(self, vector_id: int, embedding: list[float]) -> int:
    """Add ``embedding`` to the global index under ``vector_id`` and save it.

    Args:
        vector_id: Id to store the vector under.
        embedding: The vector to store.

    Returns:
        ``vector_id``, unchanged.
    """
    row = self._prepare_embedding(embedding)
    id_array = np.array([vector_id], dtype=np.int64)
    self.global_index.add_with_ids(row, id_array)
    # The global index file is rewritten after every insert.
    self._persist_index(self.global_index, self.global_index_path)
    return vector_id
|
||||
|
||||
distances, indices = self.index.search(vec, k)
|
||||
|
||||
def insert_idx(self, idx: str, vector_id: int, embedding: list[float]) -> int:
    """Add ``embedding`` to conversation index ``idx`` under ``vector_id``.

    The conversation index file is rewritten after the insert.

    Args:
        idx: Name of the conversation index.
        vector_id: Id to store the vector under.
        embedding: The vector to store.

    Returns:
        ``vector_id``, unchanged.
    """
    target = self._get_conversation_index(idx)
    row = self._prepare_embedding(embedding)
    id_array = np.array([vector_id], dtype=np.int64)
    target.add_with_ids(row, id_array)
    self._persist_index(target, self._conversation_index_path(idx))
    return vector_id
|
||||
|
||||
def _search(self, index, embedding: list[float], k: int = 5):
|
||||
vec = self._prepare_embedding(embedding)
|
||||
distances, indices = index.search(vec, k)
|
||||
ids = indices[0].tolist()
|
||||
raw_distances = distances[0].tolist()
|
||||
normalized_distances = []
|
||||
similarity_scores = []
|
||||
for vector_id, distance in zip(ids, raw_distances):
|
||||
if vector_id < 0:
|
||||
normalized_distances.append(0.0)
|
||||
similarity_scores.append(0.0)
|
||||
continue
|
||||
normalized_distances.append(distance)
|
||||
similarity_scores.append(1 - distance if not self.use_cosine else distance)
|
||||
return {
|
||||
"ids": indices[0].tolist(),
|
||||
"distances": distances[0].tolist(), # 余弦时:值越大越相似(1.0=完全相同)
|
||||
"similarity_scores": [1 - d for d in distances[0].tolist()] if not self.use_cosine else distances[0].tolist()
|
||||
"ids": ids,
|
||||
"distances": normalized_distances,
|
||||
"similarity_scores": similarity_scores,
|
||||
}
|
||||
|
||||
def search(self, embedding: list[float], k: int = 5):
    """Search the global index for the ``k`` entries closest to ``embedding``.

    Returns:
        The result mapping produced by ``_search``.
    """
    target = self.global_index
    return self._search(target, embedding, k)
|
||||
|
||||
def search_idx(self, idx: str, embedding: list[float], k: int = 5):
    """Search conversation index ``idx`` for the ``k`` entries closest to ``embedding``.

    Returns:
        The result mapping produced by ``_search``.
    """
    return self._search(self._get_conversation_index(idx), embedding, k)
|
||||
|
||||
def delete(self, idx: str) -> bool:
|
||||
self._get_conversation_index(idx)
|
||||
self.conversation_indexes.pop(idx, None)
|
||||
index_path = self._conversation_index_path(idx)
|
||||
if os.path.exists(index_path):
|
||||
os.remove(index_path)
|
||||
return True
|
||||
|
||||
def persist(self):
|
||||
"""保存索引"""
|
||||
faiss.write_index(self.index, self.index_path)
|
||||
print(f"💾 索引已保存 → {self.index_path}(共 {self.index.ntotal} 个向量)")
|
||||
self._persist_index(self.global_index, self.global_index_path)
|
||||
for idx, index in self.conversation_indexes.items():
|
||||
self._persist_index(index, self._conversation_index_path(idx))
|
||||
return True
|
||||
|
||||
|
||||
# Singleton: one FaissManager instance, constructed when this module is imported.
|
||||
faiss_manager = FaissManager()
|
||||
|
||||
1
faiss/indexes/.gitkeep
Normal file
1
faiss/indexes/.gitkeep
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
@@ -1,10 +1,28 @@
|
||||
# models.py
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EmbeddingInput(BaseModel):
    # Request body carrying a single embedding vector.
    embedding: List[float]
|
||||
|
||||
|
||||
class SearchInput(BaseModel):
|
||||
embedding: List[float]
|
||||
k: int = 5
|
||||
k: int = 5
|
||||
|
||||
|
||||
class IndexInsertInput(BaseModel):
    # Request body for inserting one vector into a named index.
    idx: str  # name of the target index
    id: int  # id to store the vector under
    embedding: List[float]  # the vector to store
|
||||
|
||||
|
||||
class IndexSearchInput(BaseModel):
    # Request body for searching a named index.
    idx: str  # name of the index to search
    embedding: List[float]  # query vector
    k: int = 5  # number of results requested
|
||||
|
||||
|
||||
class IndexDeleteInput(BaseModel):
    # Request body for deleting a named index.
    idx: str  # name of the index to delete
|
||||
|
||||
Reference in New Issue
Block a user