90 lines
2.2 KiB
Go
90 lines
2.2 KiB
Go
package faiss
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"

	"ai-chat-service/pkg/config"
)
|
|
|
|
type Client interface {
|
|
Insert(ctx context.Context, embedding []float32) (string, error)
|
|
Search(ctx context.Context, embedding []float32, k int) (*SearchResponse, error)
|
|
}
|
|
|
|
// client is the HTTP implementation of Client returned by NewClient.
type client struct {
	// baseURL is the service root with trailing slashes removed; endpoint
	// paths such as "/search" are appended to it directly.
	baseURL string

	// httpClient performs every request issued by this client.
	httpClient *http.Client
}
|
|
|
|
// SearchResponse is the decoded JSON body returned by the service's /search
// endpoint.
//
// NOTE(review): IDs, Distances and SimilarityScores look like parallel slices
// with one entry per hit — confirm against the FAISS service.
type SearchResponse struct {
	// IDs are the vector identifiers reported under "ids".
	IDs []int64 `json:"ids"`

	// Distances are the raw distances reported under "distances".
	Distances []float32 `json:"distances"`

	// SimilarityScores are the scores reported under "similarity_scores".
	SimilarityScores []float32 `json:"similarity_scores"`
}
|
|
|
|
// Wire formats exchanged with the FAISS service by Insert and Search.
type (
	// insertRequest is the JSON body POSTed to /insert.
	insertRequest struct {
		Embedding []float32 `json:"embedding"`
	}

	// insertResponse is the JSON body returned by /insert; ID is the
	// identifier the service assigned to the stored embedding.
	insertResponse struct {
		ID int64 `json:"id"`
	}

	// searchRequest is the JSON body POSTed to /search; K is the number of
	// neighbours requested.
	searchRequest struct {
		Embedding []float32 `json:"embedding"`
		K         int       `json:"k"`
	}
)
|
|
|
|
func NewClient(cnf *config.Config) Client {
|
|
return &client{
|
|
baseURL: strings.TrimRight(cnf.Faiss.BaseUrl, "/"),
|
|
httpClient: &http.Client{Timeout: time.Duration(cnf.Faiss.Timeout) * time.Second},
|
|
}
|
|
}
|
|
|
|
func (c *client) Insert(ctx context.Context, embedding []float32) (string, error) {
|
|
reqBody := &insertRequest{Embedding: embedding}
|
|
result := &insertResponse{}
|
|
if err := c.postJSON(ctx, "/insert", reqBody, result); err != nil {
|
|
return "", err
|
|
}
|
|
return fmt.Sprintf("%d", result.ID), nil
|
|
}
|
|
|
|
func (c *client) Search(ctx context.Context, embedding []float32, k int) (*SearchResponse, error) {
|
|
reqBody := &searchRequest{Embedding: embedding, K: k}
|
|
result := &SearchResponse{}
|
|
if err := c.postJSON(ctx, "/search", reqBody, result); err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (c *client) postJSON(ctx context.Context, path string, requestData any, responseData any) error {
|
|
body, err := json.Marshal(requestData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+path, bytes.NewReader(body))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return fmt.Errorf("faiss request failed: status=%d", resp.StatusCode)
|
|
}
|
|
return json.NewDecoder(resp.Body).Decode(responseData)
|
|
}
|