redis缓存替换+pgvector向量替换
This commit is contained in:
115
ai-chat-service/services/embedding/embedding.go
Normal file
115
ai-chat-service/services/embedding/embedding.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"ai-chat-service/pkg/config"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Embedder converts a piece of text into its embedding vector.
//
// Implementations receive the caller's context and return the vector as a
// slice of float32 values, or an error when no vector could be produced.
type Embedder interface {
	Embed(ctx context.Context, text string) ([]float32, error)
}
|
||||
|
||||
// openAICompatibleEmbedder is the Embedder implementation that talks to an
// HTTP service exposing an OpenAI-style "/embeddings" endpoint.
type openAICompatibleEmbedder struct {
	// baseURL is the service root; "/embeddings" is appended to it per request.
	baseURL string
	// apiKey is sent as a Bearer token in the Authorization header.
	apiKey string
	// model is the model identifier placed in the request payload.
	model string
	// dimensions is the expected vector length; a value <= 0 disables the
	// length check on responses.
	dimensions int
	// httpClient performs the outbound HTTP calls.
	httpClient *http.Client
}
|
||||
|
||||
// embeddingRequest is the JSON payload posted to the embeddings endpoint.
type embeddingRequest struct {
	// Input holds the texts to embed, serialized under the "input" key.
	Input []string `json:"input"`
	// Model names the embedding model, serialized under the "model" key.
	Model string `json:"model"`
}
|
||||
|
||||
// embeddingResponse is the subset of the embeddings endpoint's JSON reply
// that this package decodes.
type embeddingResponse struct {
	// Data is decoded from the "data" array; each element carries one vector
	// under its "embedding" key.
	Data []struct {
		Embedding []float32 `json:"embedding"`
	} `json:"data"`
}
|
||||
|
||||
func NewEmbedder(cnf *config.Config) (Embedder, error) {
|
||||
switch cnf.Embedding.Provider {
|
||||
case "openai-compatible", "":
|
||||
return &openAICompatibleEmbedder{
|
||||
baseURL: strings.TrimRight(cnf.Embedding.BaseUrl, "/"),
|
||||
apiKey: cnf.Embedding.ApiKey,
|
||||
model: cnf.Embedding.Model,
|
||||
dimensions: cnf.Vector.Pgvector.Dimensions,
|
||||
httpClient: &http.Client{Timeout: time.Duration(cnf.Embedding.Timeout) * time.Second},
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported embedding provider: %s", cnf.Embedding.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
// lineBreakNormalizer rewrites every line-break flavour to a single space.
// "\r\n" is listed first so a CRLF pair collapses to one space instead of two.
var lineBreakNormalizer = strings.NewReplacer("\r\n", " ", "\n", " ", "\r", " ")

// BuildText joins the given parts into a single line of text for embedding.
//
// Each part has its line breaks (LF, CRLF, or bare CR) replaced by spaces and
// its surrounding whitespace trimmed; parts that end up empty are dropped.
// The remaining parts are joined with the full-width comma ",". The result
// is "" when no part has content.
func BuildText(parts ...string) string {
	kept := make([]string, 0, len(parts))
	for _, part := range parts {
		part = strings.TrimSpace(lineBreakNormalizer.Replace(part))
		if part == "" {
			continue
		}
		kept = append(kept, part)
	}
	return strings.Join(kept, ",")
}
|
||||
|
||||
func (e *openAICompatibleEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
||||
text = BuildText(text)
|
||||
if text == "" {
|
||||
return nil, fmt.Errorf("embedding text is empty")
|
||||
}
|
||||
if e.baseURL == "" {
|
||||
return nil, fmt.Errorf("embedding base_url is empty")
|
||||
}
|
||||
if e.apiKey == "" {
|
||||
return nil, fmt.Errorf("embedding api_key is empty")
|
||||
}
|
||||
if e.model == "" {
|
||||
return nil, fmt.Errorf("embedding model is empty")
|
||||
}
|
||||
|
||||
reqBody := &embeddingRequest{
|
||||
Input: []string{text},
|
||||
Model: e.model,
|
||||
}
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, e.baseURL+"/embeddings", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+e.apiKey)
|
||||
|
||||
resp, err := e.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("embedding request failed: status=%d", resp.StatusCode)
|
||||
}
|
||||
|
||||
result := &embeddingResponse{}
|
||||
if err = json.NewDecoder(resp.Body).Decode(result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result.Data) == 0 || len(result.Data[0].Embedding) == 0 {
|
||||
return nil, fmt.Errorf("embedding response is empty")
|
||||
}
|
||||
if e.dimensions > 0 && len(result.Data[0].Embedding) != e.dimensions {
|
||||
return nil, fmt.Errorf("embedding dimension mismatch: got=%d want=%d", len(result.Data[0].Embedding), e.dimensions)
|
||||
}
|
||||
return result.Data[0].Embedding, nil
|
||||
}
|
||||
Reference in New Issue
Block a user