Files
mchat/ai-chat-service/services/faiss/faiss.go
2026-04-10 11:55:00 +00:00

90 lines
2.2 KiB
Go

package faiss
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"ai-chat-service/pkg/config"
)
// Client is the interface through which callers talk to the FAISS HTTP
// service. NewClient returns the implementation defined in this package.
type Client interface {
// Insert sends embedding to the service and returns the ID taken from the
// service's response, formatted as a decimal string.
Insert(ctx context.Context, embedding []float32) (string, error)
// Search sends embedding together with k to the service and returns the
// decoded response.
// NOTE(review): k is presumably the number of results requested — confirm
// against the service contract.
Search(ctx context.Context, embedding []float32, k int) (*SearchResponse, error)
}
// client is the HTTP-backed implementation of Client.
type client struct {
// baseURL is the service address that request paths are appended to;
// NewClient strips trailing "/" characters from it.
baseURL string
// httpClient performs every request issued by this client.
httpClient *http.Client
}
// SearchResponse is the JSON body that Search decodes from the service's
// "/search" response.
// NOTE(review): the three slices are presumably parallel (one entry per
// returned match) — confirm against the service contract.
type SearchResponse struct {
// IDs is decoded from the "ids" field.
IDs []int64 `json:"ids"`
// Distances is decoded from the "distances" field.
Distances []float32 `json:"distances"`
// SimilarityScores is decoded from the "similarity_scores" field.
SimilarityScores []float32 `json:"similarity_scores"`
}
// insertRequest is the JSON payload that Insert POSTs to "/insert".
type insertRequest struct {
// Embedding is the vector passed to Insert, sent as the "embedding" field.
Embedding []float32 `json:"embedding"`
}
// insertResponse is the JSON body that Insert decodes from the "/insert"
// response.
type insertResponse struct {
// ID is read from the "id" field; Insert returns it as a decimal string.
ID int64 `json:"id"`
}
// searchRequest is the JSON payload that Search POSTs to "/search".
type searchRequest struct {
// Embedding is the vector passed to Search, sent as the "embedding" field.
Embedding []float32 `json:"embedding"`
// K is the k argument passed to Search, sent as the "k" field.
// NOTE(review): presumably the number of results requested — confirm.
K int `json:"k"`
}
// NewClient builds a Client from the Faiss section of cnf.
//
// Trailing "/" characters are removed from cnf.Faiss.BaseUrl so that request
// paths can be appended directly. cnf.Faiss.Timeout is interpreted as a number
// of seconds and applied as the timeout of the underlying http.Client.
func NewClient(cnf *config.Config) Client {
	faissClient := new(client)
	faissClient.baseURL = strings.TrimRight(cnf.Faiss.BaseUrl, "/")
	faissClient.httpClient = &http.Client{
		Timeout: time.Duration(cnf.Faiss.Timeout) * time.Second,
	}
	return faissClient
}
// Insert POSTs embedding to the service's "/insert" endpoint and returns the
// ID found in the response, rendered as a decimal string.
//
// On any error from the request the returned string is empty.
func (c *client) Insert(ctx context.Context, embedding []float32) (string, error) {
	var created insertResponse
	payload := insertRequest{Embedding: embedding}

	err := c.postJSON(ctx, "/insert", &payload, &created)
	if err != nil {
		return "", err
	}

	id := fmt.Sprintf("%d", created.ID)
	return id, nil
}
// Search POSTs embedding and k to the service's "/search" endpoint and returns
// the decoded response.
//
// On any error from the request the returned response is nil.
func (c *client) Search(ctx context.Context, embedding []float32, k int) (*SearchResponse, error) {
	var found SearchResponse
	payload := searchRequest{
		Embedding: embedding,
		K:         k,
	}

	err := c.postJSON(ctx, "/search", &payload, &found)
	if err != nil {
		return nil, err
	}
	return &found, nil
}
// postJSON marshals requestData as JSON, POSTs it to c.baseURL+path with a
// "Content-Type: application/json" header, and decodes the JSON response body
// into responseData (which must be a non-nil pointer).
//
// A status code outside 2xx is reported as an error carrying the status and,
// when the server sent one, a short excerpt of the response body. Encoding and
// decoding failures are wrapped with the request path; the underlying error
// remains reachable through errors.Is / errors.As.
//
// Whatever is left of the response body is drained (up to a fixed limit) before
// it is closed, so that the HTTP transport can reuse the connection.
func (c *client) postJSON(ctx context.Context, path string, requestData any, responseData any) error {
	const (
		// maxErrBody caps how much of an error response is quoted in the
		// returned error.
		maxErrBody = 1 << 10
		// maxDrain caps how much unread body is discarded before closing;
		// beyond this it is cheaper to drop the connection than to read on.
		maxDrain = 64 << 10
	)

	body, err := json.Marshal(requestData)
	if err != nil {
		return fmt.Errorf("faiss %s: encoding request: %w", path, err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+path, bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("faiss %s: building request: %w", path, err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		// Client.Do already returns a *url.Error naming the method and URL,
		// so no extra context is added here.
		return err
	}
	defer func() {
		// The transport only reuses a connection whose body was read to EOF
		// and closed; the JSON decoder and the error path may stop short.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrain))
		_ = resp.Body.Close()
	}()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: a failure to read the error body must not hide the
		// status code, so the read error is deliberately ignored.
		detail, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		if detail = bytes.TrimSpace(detail); len(detail) > 0 {
			return fmt.Errorf("faiss request failed: status=%d body=%q", resp.StatusCode, detail)
		}
		return fmt.Errorf("faiss request failed: status=%d", resp.StatusCode)
	}

	if err := json.NewDecoder(resp.Body).Decode(responseData); err != nil {
		return fmt.Errorf("faiss %s: decoding response: %w", path, err)
	}
	return nil
}