116 lines
2.9 KiB
Go
116 lines
2.9 KiB
Go
package embedding
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"ai-chat-service/pkg/config"
)
|
|
|
|
// Embedder turns a piece of text into its embedding vector.
type Embedder interface {
	// Embed returns the embedding vector for text, or a non-nil error when no
	// vector could be produced. ctx is handed to the implementation so the
	// caller can cancel or bound the operation.
	Embed(ctx context.Context, text string) ([]float32, error)
}
|
|
|
|
// openAICompatibleEmbedder is the Embedder backed by an HTTP service that
// exposes an OpenAI-style POST {baseURL}/embeddings endpoint.
type openAICompatibleEmbedder struct {
	// baseURL is the service root; "/embeddings" is appended for each request.
	baseURL string
	// apiKey is sent as "Authorization: Bearer <apiKey>".
	apiKey string
	// model is the model name placed in the request body.
	model string
	// dimensions is the expected vector length. Values <= 0 disable the
	// length check on responses.
	dimensions int
	// httpClient performs the HTTP requests.
	httpClient *http.Client
}
|
|
|
|
// embeddingRequest is the JSON body sent to the embeddings endpoint.
type embeddingRequest struct {
	// Input holds the texts to embed; serialized under the "input" key.
	Input []string `json:"input"`
	// Model names the embedding model; serialized under the "model" key.
	Model string `json:"model"`
}
|
|
|
|
// embeddingResponse is the part of the embeddings endpoint's JSON reply that
// this package decodes: the "data" array and, for each entry, its "embedding"
// vector. Any other fields in the reply are ignored by the decoder.
type embeddingResponse struct {
	Data []struct {
		Embedding []float32 `json:"embedding"`
	} `json:"data"`
}
|
|
|
|
func NewEmbedder(cnf *config.Config) (Embedder, error) {
|
|
switch cnf.Embedding.Provider {
|
|
case "openai-compatible", "":
|
|
return &openAICompatibleEmbedder{
|
|
baseURL: strings.TrimRight(cnf.Embedding.BaseUrl, "/"),
|
|
apiKey: cnf.Embedding.ApiKey,
|
|
model: cnf.Embedding.Model,
|
|
dimensions: cnf.Vector.Pgvector.Dimensions,
|
|
httpClient: &http.Client{Timeout: time.Duration(cnf.Embedding.Timeout) * time.Second},
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported embedding provider: %s", cnf.Embedding.Provider)
|
|
}
|
|
}
|
|
|
|
// buildTextLineBreaks rewrites every line-break sequence to a single space.
// "\r\n" is listed first so a CRLF pair becomes one space rather than two.
var buildTextLineBreaks = strings.NewReplacer("\r\n", " ", "\n", " ", "\r", " ")

// BuildText flattens parts into a single-line, comma-separated string.
//
// Each part has its line breaks ("\r\n", "\n" or "\r") replaced by a space and
// its surrounding whitespace trimmed; parts that end up empty are dropped. The
// remaining parts are joined with "," in their original order. With no usable
// parts the result is the empty string.
func BuildText(parts ...string) string {
	cleaned := make([]string, 0, len(parts))
	for _, part := range parts {
		part = strings.TrimSpace(buildTextLineBreaks.Replace(part))
		if part == "" {
			continue
		}
		cleaned = append(cleaned, part)
	}
	return strings.Join(cleaned, ",")
}
|
|
|
|
func (e *openAICompatibleEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
|
text = BuildText(text)
|
|
if text == "" {
|
|
return nil, fmt.Errorf("embedding text is empty")
|
|
}
|
|
if e.baseURL == "" {
|
|
return nil, fmt.Errorf("embedding base_url is empty")
|
|
}
|
|
if e.apiKey == "" {
|
|
return nil, fmt.Errorf("embedding api_key is empty")
|
|
}
|
|
if e.model == "" {
|
|
return nil, fmt.Errorf("embedding model is empty")
|
|
}
|
|
|
|
reqBody := &embeddingRequest{
|
|
Input: []string{text},
|
|
Model: e.model,
|
|
}
|
|
body, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.baseURL+"/embeddings", bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+e.apiKey)
|
|
|
|
resp, err := e.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("embedding request failed: status=%d", resp.StatusCode)
|
|
}
|
|
|
|
result := &embeddingResponse{}
|
|
if err = json.NewDecoder(resp.Body).Decode(result); err != nil {
|
|
return nil, err
|
|
}
|
|
if len(result.Data) == 0 || len(result.Data[0].Embedding) == 0 {
|
|
return nil, fmt.Errorf("embedding response is empty")
|
|
}
|
|
if e.dimensions > 0 && len(result.Data[0].Embedding) != e.dimensions {
|
|
return nil, fmt.Errorf("embedding dimension mismatch: got=%d want=%d", len(result.Data[0].Embedding), e.dimensions)
|
|
}
|
|
return result.Data[0].Embedding, nil
|
|
}
|