提供对话上下文索引

This commit is contained in:
2026-04-10 12:47:39 +00:00
parent 8e39e609cc
commit d4e6142131
10 changed files with 267 additions and 90 deletions

View File

@@ -17,9 +17,9 @@ chat:
bot_desc: "你是一个AI助手我需要你模拟一名资深的软件工程师来回答我的问题" bot_desc: "你是一个AI助手我需要你模拟一名资深的软件工程师来回答我的问题"
min_response_tokens: 600 min_response_tokens: 600
redis: redis:
host: "host.docker.internal" host: "redis"
port: 8888 port: 6379
pwd: "123456" pwd: ""
dependOn: dependOn:
sensitive: sensitive:
address: "sensitive-filter:50053" address: "sensitive-filter:50053"
@@ -36,7 +36,7 @@ embedding:
model: "embedding-2" model: "embedding-2"
timeout: 10 timeout: 10
faiss: faiss:
base_url: "http://host.docker.internal:8451" base_url: "http://faiss:8000"
search_k: 1 search_k: 1
similarity_threshold: 0.9 similarity_threshold: 0.9
timeout: 10 timeout: 10

View File

@@ -1,4 +1,18 @@
services: services:
faiss:
build:
context: ../faiss
container_name: faiss-service
ports:
- "8451:8000"
volumes:
- ../faiss/indexes:/app/indexes # 持久化索引目录
- ../faiss/.env:/app/.env # 可选:挂载配置
restart: unless-stopped
environment:
- FAISS_DIM=1024
- APP_PORT=8000
redis: redis:
image: redis:7-alpine image: redis:7-alpine
container_name: ai-chat-redis container_name: ai-chat-redis
@@ -60,12 +74,10 @@ services:
- .env - .env
volumes: volumes:
- ./configs/ai-chat-service.yaml:/app/config.yaml:ro - ./configs/ai-chat-service.yaml:/app/config.yaml:ro
extra_hosts:
- "host.docker.internal:host-gateway"
ports: ports:
- "50055:50055" - "50055:50055"
depends_on: depends_on:
- ai-chat-redis - faiss
- tokenizer - tokenizer
- sensitive-filter - sensitive-filter
- keywords-filter - keywords-filter

View File

@@ -36,7 +36,7 @@ embedding:
model: "embedding-2" model: "embedding-2"
timeout: 10 timeout: 10
faiss: faiss:
base_url: "http://host.docker.internal:8451" base_url: "http://faiss:8000"
search_k: 1 search_k: 1
similarity_threshold: 0.9 similarity_threshold: 0.9
timeout: 10 timeout: 10

5
faiss/.gitignore vendored
View File

@@ -1,3 +1,6 @@
faiss_index.bin indexes/*
!indexes/.gitkeep
!indexes/global/.gitkeep
!indexes/conversations/.gitkeep
.vscode .vscode
__pycache__ __pycache__

View File

@@ -1,41 +1,81 @@
# api.py from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, Depends, HTTPException
from models import EmbeddingInput, SearchInput
from faiss_manager import faiss_manager
from config import get_settings from config import get_settings
from faiss_manager import faiss_manager
from models import EmbeddingInput, IndexDeleteInput, IndexInsertInput, IndexSearchInput, SearchInput
settings = get_settings() settings = get_settings()
app = FastAPI( app = FastAPI(
title="FAISS 服务", title="FAISS 服务",
description="向量插入 + 相似搜索 + 持久化", description="向量插入 + 相似搜索 + 多索引管理",
version="1.0.0" version="1.0.0",
) )
def log_business(message: str): def log_business(message: str):
if settings.ENABLE_REQUEST_LOGS: if settings.ENABLE_REQUEST_LOGS:
print(message, flush=True) print(message, flush=True)
@app.post("/insert") @app.post("/insert")
async def insert(data: EmbeddingInput): async def insert(data: EmbeddingInput):
try: try:
vector_id = faiss_manager.insert(data.embedding) vector_id = faiss_manager.insert(data.embedding)
log_business(f"[faiss] insert id={vector_id}") log_business(f"[faiss] global insert id={vector_id}")
return {"id": vector_id} return {"id": vector_id}
except Exception as e: except Exception as exc:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(exc))
@app.post("/search") @app.post("/search")
async def search(data: SearchInput): async def search(data: SearchInput):
try: try:
result = faiss_manager.search(data.embedding, data.k) result = faiss_manager.search(data.embedding, data.k)
log_business( log_business(f"[faiss] global search ids={result['ids']} similarity_scores={result['similarity_scores']}")
f"[faiss] search ids={result['ids']} similarity_scores={result['similarity_scores']}",
)
return result return result
except Exception as e: except Exception as exc:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(exc))
@app.post("/create")
async def create():
try:
idx = faiss_manager.create()
log_business(f"[faiss] create idx={idx}")
return {"idx": idx}
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
@app.post("/insert_idx")
async def insert_idx(data: IndexInsertInput):
try:
vector_id = faiss_manager.insert_idx(data.idx, data.id, data.embedding)
log_business(f"[faiss] insert_idx idx={data.idx} id={vector_id}")
return {"id": vector_id, "idx": data.idx}
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
@app.post("/search_idx")
async def search_idx(data: IndexSearchInput):
try:
result = faiss_manager.search_idx(data.idx, data.embedding, data.k)
log_business(f"[faiss] search_idx idx={data.idx} ids={result['ids']} similarity_scores={result['similarity_scores']}")
return result
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
@app.post("/del")
async def delete_idx(data: IndexDeleteInput):
try:
faiss_manager.delete(data.idx)
log_business(f"[faiss] delete idx={data.idx}")
return {"status": "success", "idx": data.idx}
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
@app.post("/persist") @app.post("/persist")

View File

@@ -5,6 +5,10 @@ from functools import lru_cache
class Settings(BaseSettings): class Settings(BaseSettings):
# FAISS 配置(已优化为你的 LLM 相似问题缓存场景) # FAISS 配置(已优化为你的 LLM 相似问题缓存场景)
FAISS_DIM: int = 1024 # 根据你的 embedding 模型修改e.g. bge-large=1024, text-embedding-3-large=3072 FAISS_DIM: int = 1024 # 根据你的 embedding 模型修改e.g. bge-large=1024, text-embedding-3-large=3072
FAISS_INDEX_DIR: str = "indexes"
FAISS_GLOBAL_DIR: str = "global"
FAISS_CONVERSATION_DIR: str = "conversations"
FAISS_GLOBAL_INDEX_NAME: str = "global_qa.index"
FAISS_INDEX_PATH: str = "faiss_index.bin" FAISS_INDEX_PATH: str = "faiss_index.bin"
FAISS_INDEX_TYPE: str = "HNSW" # 默认改为 HNSW最推荐 FAISS_INDEX_TYPE: str = "HNSW" # 默认改为 HNSW最推荐

View File

@@ -1,13 +0,0 @@
services:
faiss:
build: .
container_name: faiss-service
ports:
- "8451:8000"
volumes:
- ./faiss_index.bin:/app/faiss_index.bin # 持久化索引文件
- ./.env:/app/.env # 可选:挂载配置
restart: unless-stopped
environment:
- FAISS_DIM=1024
- APP_PORT=8000

View File

@@ -1,87 +1,199 @@
# faiss_manager.py
import os import os
import numpy as np import uuid
import faiss import faiss
import numpy as np
from config import get_settings from config import get_settings
settings = get_settings() settings = get_settings()
class FaissManager: class FaissManager:
def __init__(self): def __init__(self):
self.dim = settings.FAISS_DIM self.dim = settings.FAISS_DIM
self.index_path = settings.FAISS_INDEX_PATH self.index_dir = settings.FAISS_INDEX_DIR
self.global_dir = os.path.join(self.index_dir, settings.FAISS_GLOBAL_DIR)
self.conversation_dir = os.path.join(self.index_dir, settings.FAISS_CONVERSATION_DIR)
self.global_index_path = os.path.join(self.global_dir, settings.FAISS_GLOBAL_INDEX_NAME)
self.legacy_index_path = settings.FAISS_INDEX_PATH
self.use_cosine = settings.USE_COSINE_SIMILARITY self.use_cosine = settings.USE_COSINE_SIMILARITY
self.index = None self.global_index = None
self._load_or_create_index() self.conversation_indexes = {}
self._load_indexes()
def _load_or_create_index(self): def _load_indexes(self):
if os.path.exists(self.index_path): os.makedirs(self.global_dir, exist_ok=True)
self.index = faiss.read_index(self.index_path) os.makedirs(self.conversation_dir, exist_ok=True)
print(f"✅ 加载已有索引:{self.index.ntotal} 个向量,维度={self.index.d}")
return if os.path.isfile(self.global_index_path):
self.global_index = self._load_index_from_path(self.global_index_path)
print(f"✅ 加载全局索引:{self.global_index.ntotal} 个向量")
elif os.path.isfile(self.legacy_index_path):
legacy_index = faiss.read_index(self.legacy_index_path)
self.global_index = self._ensure_id_map(legacy_index)
self._persist_index(self.global_index, self.global_index_path)
print(f"✅ 从旧索引迁移全局索引:{self.global_index.ntotal} 个向量")
else:
self.global_index = self._create_index("global")
self._persist_index(self.global_index, self.global_index_path)
print("✅ 创建全局索引")
for file_name in sorted(os.listdir(self.conversation_dir)):
if not file_name.endswith(".index"):
continue
idx = file_name[: -len(".index")]
index_path = self._conversation_index_path(idx)
self.conversation_indexes[idx] = self._load_index_from_path(index_path)
if self.conversation_indexes:
print(f"✅ 加载对话索引:{len(self.conversation_indexes)}")
def _create_base_index(self, kind: str):
if kind == "conversation":
if self.use_cosine:
return faiss.IndexFlatIP(self.dim)
return faiss.IndexFlatL2(self.dim)
# 创建新索引
if settings.FAISS_INDEX_TYPE == "HNSW": if settings.FAISS_INDEX_TYPE == "HNSW":
if self.use_cosine: if self.use_cosine:
self.index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M, faiss.METRIC_INNER_PRODUCT) index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M, faiss.METRIC_INNER_PRODUCT)
print("✅ 创建 HNSWIP 索引(余弦相似度)")
else: else:
self.index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M) index = faiss.IndexHNSWFlat(self.dim, settings.HNSW_M)
print("✅ 创建 HNSWFlat 索引L2 距离)") index.hnsw.efConstruction = settings.HNSW_EF_CONSTRUCTION
index.hnsw.efSearch = settings.HNSW_EF_SEARCH
return index
if settings.FAISS_INDEX_TYPE == "FlatIP" and self.use_cosine:
return faiss.IndexFlatIP(self.dim)
return faiss.IndexFlatL2(self.dim)
# 设置 HNSW 参数 def _create_index(self, kind: str):
self.index.hnsw.efConstruction = settings.HNSW_EF_CONSTRUCTION return faiss.IndexIDMap2(self._create_base_index(kind))
self.index.hnsw.efSearch = settings.HNSW_EF_SEARCH
print(f" HNSW 参数: M={settings.HNSW_M}, efConstruction={settings.HNSW_EF_CONSTRUCTION}, efSearch={settings.HNSW_EF_SEARCH}")
elif settings.FAISS_INDEX_TYPE == "FlatIP" and self.use_cosine:
self.index = faiss.IndexFlatIP(self.dim)
print("✅ 创建 FlatIP 索引(精确余弦)")
else:
# 默认精确 L2兼容旧配置
self.index = faiss.IndexFlatL2(self.dim)
print("✅ 创建 FlatL2 索引(精确欧式)")
def _normalize(self, embedding: list[float]) -> np.ndarray: def _normalize(self, embedding: list[float]) -> np.ndarray:
"""L2 归一化(余弦相似度必需)"""
vec = np.array(embedding, dtype=np.float32) vec = np.array(embedding, dtype=np.float32)
norm = np.linalg.norm(vec) norm = np.linalg.norm(vec)
return vec / norm if norm > 0 else vec return vec / norm if norm > 0 else vec
def insert(self, embedding: list[float]) -> int: def _prepare_embedding(self, embedding: list[float]) -> np.ndarray:
"""插入向量,返回 ID"""
if len(embedding) != self.dim: if len(embedding) != self.dim:
raise ValueError(f"Embedding 维度错误,应为 {self.dim}") raise ValueError(f"Embedding 维度错误,应为 {self.dim}")
vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32) vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32)
vec = vec.reshape(1, -1) return vec.reshape(1, -1)
idx = self.index.ntotal def _conversation_index_path(self, idx: str) -> str:
self.index.add(vec) return os.path.join(self.conversation_dir, f"{idx}.index")
def _persist_index(self, index, path: str):
os.makedirs(os.path.dirname(path), exist_ok=True)
faiss.write_index(index, path)
def _load_index_from_path(self, path: str):
index = faiss.read_index(path)
return self._ensure_id_map(index)
def _ensure_id_map(self, index):
if hasattr(index, "id_map"):
return index
converted = faiss.IndexIDMap2(self._create_base_index("global"))
if index.ntotal == 0:
return converted
vectors = np.empty((index.ntotal, index.d), dtype=np.float32)
index.reconstruct_n(0, index.ntotal, vectors)
ids = np.arange(index.ntotal, dtype=np.int64)
converted.add_with_ids(vectors, ids)
return converted
def _index_ids(self, index) -> np.ndarray:
if not hasattr(index, "id_map"):
return np.arange(index.ntotal, dtype=np.int64)
return faiss.vector_to_array(index.id_map)
def _next_global_id(self) -> int:
ids = self._index_ids(self.global_index)
if ids.size == 0:
return 0
return int(ids.max()) + 1
def _get_conversation_index(self, idx: str):
if idx in self.conversation_indexes:
return self.conversation_indexes[idx]
index_path = self._conversation_index_path(idx)
if not os.path.isfile(index_path):
raise ValueError(f"索引不存在: {idx}")
index = self._load_index_from_path(index_path)
self.conversation_indexes[idx] = index
return index
def create(self) -> str:
idx = f"conv_{uuid.uuid4().hex}"
index = self._create_index("conversation")
self.conversation_indexes[idx] = index
self._persist_index(index, self._conversation_index_path(idx))
return idx return idx
def search(self, embedding: list[float], k: int = 5): def insert(self, embedding: list[float]) -> int:
"""搜索相似向量(返回 id + 距离)""" vector_id = self._next_global_id()
if len(embedding) != self.dim: self.insert_global_with_id(vector_id, embedding)
raise ValueError(f"Embedding 维度错误,应为 {self.dim}") return vector_id
vec = self._normalize(embedding) if self.use_cosine else np.array(embedding, dtype=np.float32) def insert_global_with_id(self, vector_id: int, embedding: list[float]) -> int:
vec = vec.reshape(1, -1) vec = self._prepare_embedding(embedding)
ids = np.array([vector_id], dtype=np.int64)
self.global_index.add_with_ids(vec, ids)
self._persist_index(self.global_index, self.global_index_path)
return vector_id
distances, indices = self.index.search(vec, k) def insert_idx(self, idx: str, vector_id: int, embedding: list[float]) -> int:
index = self._get_conversation_index(idx)
vec = self._prepare_embedding(embedding)
ids = np.array([vector_id], dtype=np.int64)
index.add_with_ids(vec, ids)
self._persist_index(index, self._conversation_index_path(idx))
return vector_id
def _search(self, index, embedding: list[float], k: int = 5):
vec = self._prepare_embedding(embedding)
distances, indices = index.search(vec, k)
ids = indices[0].tolist()
raw_distances = distances[0].tolist()
normalized_distances = []
similarity_scores = []
for vector_id, distance in zip(ids, raw_distances):
if vector_id < 0:
normalized_distances.append(0.0)
similarity_scores.append(0.0)
continue
normalized_distances.append(distance)
similarity_scores.append(1 - distance if not self.use_cosine else distance)
return { return {
"ids": indices[0].tolist(), "ids": ids,
"distances": distances[0].tolist(), # 余弦时值越大越相似1.0=完全相同) "distances": normalized_distances,
"similarity_scores": [1 - d for d in distances[0].tolist()] if not self.use_cosine else distances[0].tolist() "similarity_scores": similarity_scores,
} }
def search(self, embedding: list[float], k: int = 5):
return self._search(self.global_index, embedding, k)
def search_idx(self, idx: str, embedding: list[float], k: int = 5):
index = self._get_conversation_index(idx)
return self._search(index, embedding, k)
def delete(self, idx: str) -> bool:
self._get_conversation_index(idx)
self.conversation_indexes.pop(idx, None)
index_path = self._conversation_index_path(idx)
if os.path.exists(index_path):
os.remove(index_path)
return True
def persist(self): def persist(self):
"""保存索引""" self._persist_index(self.global_index, self.global_index_path)
faiss.write_index(self.index, self.index_path) for idx, index in self.conversation_indexes.items():
print(f"💾 索引已保存 → {self.index_path}(共 {self.index.ntotal} 个向量)") self._persist_index(index, self._conversation_index_path(idx))
return True return True
# 单例
faiss_manager = FaissManager() faiss_manager = FaissManager()

1
faiss/indexes/.gitkeep Normal file
View File

@@ -0,0 +1 @@

View File

@@ -1,10 +1,28 @@
# models.py
from pydantic import BaseModel
from typing import List from typing import List
from pydantic import BaseModel
class EmbeddingInput(BaseModel): class EmbeddingInput(BaseModel):
embedding: List[float] embedding: List[float]
class SearchInput(BaseModel): class SearchInput(BaseModel):
embedding: List[float] embedding: List[float]
k: int = 5 k: int = 5
class IndexInsertInput(BaseModel):
idx: str
id: int
embedding: List[float]
class IndexSearchInput(BaseModel):
idx: str
embedding: List[float]
k: int = 5
class IndexDeleteInput(BaseModel):
idx: str