122 lines
3.0 KiB
Go
122 lines
3.0 KiB
Go
package vector_data
|
|
|
|
import (
|
|
"ai-chat-service/pkg/config"
|
|
"ai-chat-service/services/embedding"
|
|
"context"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx"
|
|
)
|
|
|
|
type pgvectorChatRecordsData struct {
|
|
config *config.Config
|
|
pool *pgx.ConnPool
|
|
embedder embedding.Embedder
|
|
}
|
|
|
|
func newPgvectorChatRecordsData(config *config.Config) (IChatRecordsData, error) {
|
|
connConfig, err := pgx.ParseConnectionString(config.Vector.Pgvector.DSN)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pool, err := pgx.NewConnPool(pgx.ConnPoolConfig{
|
|
ConnConfig: connConfig,
|
|
MaxConnections: config.Vector.Pgvector.MaxOpenConn,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
embedder, err := embedding.NewEmbedder(config)
|
|
if err != nil {
|
|
pool.Close()
|
|
return nil, err
|
|
}
|
|
return &pgvectorChatRecordsData{
|
|
config: config,
|
|
pool: pool,
|
|
embedder: embedder,
|
|
}, nil
|
|
}
|
|
|
|
func (data *pgvectorChatRecordsData) UpsertData(ctx context.Context, list []*ChatRecord) error {
|
|
table := data.config.Vector.Pgvector.Table
|
|
if table == "" {
|
|
table = "chat_record_vectors"
|
|
}
|
|
for _, item := range list {
|
|
recordID, err := strconv.ParseInt(item.ID, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
keywordsText := embedding.BuildText(item.KVs["keywords"])
|
|
if keywordsText == "" {
|
|
continue
|
|
}
|
|
vector, err := data.embedder.Embed(ctx, keywordsText)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
commandTag, err := data.pool.Exec(
|
|
fmt.Sprintf(
|
|
"INSERT INTO %s (record_id, keywords_text, embedding, created_at) VALUES ($1, $2, $3::vector, $4) ON CONFLICT (record_id) DO UPDATE SET keywords_text = EXCLUDED.keywords_text, embedding = EXCLUDED.embedding, created_at = EXCLUDED.created_at",
|
|
table,
|
|
),
|
|
recordID,
|
|
keywordsText,
|
|
vectorLiteral(vector),
|
|
time.Now().Unix(),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if commandTag.RowsAffected() == 0 {
|
|
return fmt.Errorf("pgvector upsert affected 0 rows for record_id=%d", recordID)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (data *pgvectorChatRecordsData) QueryData(ctx context.Context, text map[string][]string) (id string, score float32, err error) {
|
|
keywordsText := embedding.BuildText(text["keywords"]...)
|
|
if keywordsText == "" {
|
|
return "", 0, nil
|
|
}
|
|
vector, err := data.embedder.Embed(ctx, keywordsText)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
table := data.config.Vector.Pgvector.Table
|
|
if table == "" {
|
|
table = "chat_record_vectors"
|
|
}
|
|
var recordID int64
|
|
err = data.pool.QueryRowEx(
|
|
ctx,
|
|
fmt.Sprintf(
|
|
"SELECT record_id, CAST(1 - (embedding <=> $1::vector) AS real) AS score FROM %s ORDER BY embedding <=> $1::vector LIMIT 1",
|
|
table,
|
|
),
|
|
nil,
|
|
vectorLiteral(vector),
|
|
).Scan(&recordID, &score)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return "", 0, nil
|
|
}
|
|
return "", 0, err
|
|
}
|
|
return strconv.FormatInt(recordID, 10), score, nil
|
|
}
|
|
|
|
// vectorLiteral renders values in the bracketed, comma-separated text form that
// the SQL in this file casts with ::vector, e.g. []float32{1, 0.5} -> "[1,0.5]".
// An empty or nil slice yields "[]".
func vectorLiteral(values []float32) string {
	var sb strings.Builder
	sb.WriteByte('[')
	for i, v := range values {
		if i > 0 {
			sb.WriteByte(',')
		}
		// Precision -1 with bitSize 32 gives the shortest decimal that
		// round-trips the float32 value.
		sb.WriteString(strconv.FormatFloat(float64(v), 'f', -1, 32))
	}
	sb.WriteByte(']')
	return sb.String()
}
|