# File viewer metadata: 45 lines, 1.2 KiB, Python
# api.py
|
|
from fastapi import FastAPI, Depends, HTTPException
|
|
from models import EmbeddingInput, SearchInput
|
|
from faiss_manager import faiss_manager
|
|
from config import get_settings
|
|
|
|
# Settings object obtained once at import time; read by log_business below.
settings = get_settings()

# FastAPI application instance that the route decorators register against.
app = FastAPI(
    title="FAISS 服务",
    description="向量插入 + 相似搜索 + 持久化",
    version="1.0.0",
)
|
|
|
|
def log_business(message: str):
    """Print a business log line when request logging is enabled.

    Args:
        message: Text written to standard output.
    """
    # Guard clause: stay silent unless settings.ENABLE_REQUEST_LOGS is truthy.
    if not settings.ENABLE_REQUEST_LOGS:
        return
    # flush=True is passed so the line is flushed on every call.
    print(message, flush=True)
|
|
|
|
@app.post("/insert")
async def insert(data: EmbeddingInput):
    """Pass the request's embedding to ``faiss_manager.insert`` and return its id.

    Args:
        data: Request body; its ``embedding`` attribute is forwarded to
            ``faiss_manager.insert``.

    Returns:
        A dict ``{"id": vector_id}`` where ``vector_id`` is the value returned
        by ``faiss_manager.insert``.

    Raises:
        HTTPException: Status 400 with the message of any exception raised
            while inserting or logging as its detail.
    """
    try:
        vector_id = faiss_manager.insert(data.embedding)
        log_business(f"[faiss] insert id={vector_id}")
        return {"id": vector_id}
    except Exception as e:
        # Chain explicitly so the original error is kept as __cause__
        # instead of being reported as an error raised during handling.
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/search")
async def search(data: SearchInput):
    """Run ``faiss_manager.search`` for the request and return its result.

    Args:
        data: Request body; its ``embedding`` and ``k`` attributes are
            forwarded to ``faiss_manager.search``.

    Returns:
        The object returned by ``faiss_manager.search``, unchanged. The log
        line reads its ``"ids"`` and ``"similarity_scores"`` keys.

    Raises:
        HTTPException: Status 400 with the message of any exception raised
            while searching or logging as its detail.
    """
    try:
        result = faiss_manager.search(data.embedding, data.k)
        log_business(
            f"[faiss] search ids={result['ids']} similarity_scores={result['similarity_scores']}",
        )
        return result
    except Exception as e:
        # Chain explicitly so the original error is kept as __cause__
        # instead of being reported as an error raised during handling.
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/persist")
async def persist():
    """Call ``faiss_manager.persist()`` and report success.

    Exceptions raised by the persist call are not caught here.

    Returns:
        A dict with ``"status"`` and ``"message"`` keys describing success.
    """
    faiss_manager.persist()
    response = {"status": "success", "message": "索引已持久化"}
    return response
|