package vector_data

import (
	"context"
	"fmt"
	"strconv"
	"strings"
	"time"

	"ai-chat-service/pkg/config"
	"ai-chat-service/services/embedding"

	"github.com/jackc/pgx"
)

// pgvectorDefaultTable is the table used when the configuration leaves
// Vector.Pgvector.Table empty.
const pgvectorDefaultTable = "chat_record_vectors"

// pgvectorChatRecordsData stores and searches chat-record keyword embeddings
// in a PostgreSQL table through a pgx connection pool. The SQL it issues
// relies on the "vector" type and the "<=>" operator.
type pgvectorChatRecordsData struct {
	config   *config.Config
	pool     *pgx.ConnPool
	embedder embedding.Embedder
}

// newPgvectorChatRecordsData builds the store from configuration.
//
// It parses Vector.Pgvector.DSN, opens a connection pool capped at
// Vector.Pgvector.MaxOpenConn connections, and creates the embedder. If the
// embedder cannot be created, the already-opened pool is closed before the
// error is returned so no connections are leaked.
func newPgvectorChatRecordsData(config *config.Config) (IChatRecordsData, error) {
	connConfig, err := pgx.ParseConnectionString(config.Vector.Pgvector.DSN)
	if err != nil {
		return nil, err
	}

	pool, err := pgx.NewConnPool(pgx.ConnPoolConfig{
		ConnConfig:     connConfig,
		MaxConnections: config.Vector.Pgvector.MaxOpenConn,
	})
	if err != nil {
		return nil, err
	}

	embedder, err := embedding.NewEmbedder(config)
	if err != nil {
		pool.Close()
		return nil, err
	}

	return &pgvectorChatRecordsData{
		config:   config,
		pool:     pool,
		embedder: embedder,
	}, nil
}

// tableName returns the configured table, falling back to
// pgvectorDefaultTable when none is set.
//
// The name is interpolated into SQL text without quoting or escaping, so it
// must come from trusted configuration only.
func (d *pgvectorChatRecordsData) tableName() string {
	if table := d.config.Vector.Pgvector.Table; table != "" {
		return table
	}
	return pgvectorDefaultTable
}

// UpsertData embeds the "keywords" entry of every record in list and inserts
// it into the vector table, overwriting the existing row when record_id
// already exists.
//
// Each record's ID must parse as a base-10 int64. Records whose keywords text
// is empty are skipped. Processing stops at the first error; statements are
// executed one by one without a surrounding transaction, so rows written
// before the failure stay written. ctx is passed to both the embedder and the
// database call.
func (d *pgvectorChatRecordsData) UpsertData(ctx context.Context, list []*ChatRecord) error {
	// The statement text does not depend on the record, so build it once.
	upsertSQL := fmt.Sprintf(
		"INSERT INTO %s (record_id, keywords_text, embedding, created_at) VALUES ($1, $2, $3::vector, $4) ON CONFLICT (record_id) DO UPDATE SET keywords_text = EXCLUDED.keywords_text, embedding = EXCLUDED.embedding, created_at = EXCLUDED.created_at",
		d.tableName(),
	)

	for _, item := range list {
		recordID, err := strconv.ParseInt(item.ID, 10, 64)
		if err != nil {
			return err
		}

		keywordsText := embedding.BuildText(item.KVs["keywords"])
		if keywordsText == "" {
			continue
		}

		vector, err := d.embedder.Embed(ctx, keywordsText)
		if err != nil {
			return err
		}

		// ExecEx (rather than Exec) so cancellation and deadlines on ctx
		// reach the database call, matching QueryData's use of QueryRowEx.
		commandTag, err := d.pool.ExecEx(
			ctx,
			upsertSQL,
			nil,
			recordID,
			keywordsText,
			vectorLiteral(vector),
			time.Now().Unix(), // created_at is stored as Unix seconds
		)
		if err != nil {
			return err
		}
		if commandTag.RowsAffected() == 0 {
			return fmt.Errorf("pgvector upsert affected 0 rows for record_id=%d", recordID)
		}
	}
	return nil
}

// QueryData embeds text["keywords"] and returns the ID of the stored record
// whose embedding is nearest to it, together with a score.
//
// Rows are ordered by "embedding <=> query" and the score is computed as
// 1 - (embedding <=> query), cast to real. It returns ("", 0, nil) when the
// keywords text is empty or the table has no rows.
func (d *pgvectorChatRecordsData) QueryData(ctx context.Context, text map[string][]string) (id string, score float32, err error) {
	keywordsText := embedding.BuildText(text["keywords"]...)
	if keywordsText == "" {
		return "", 0, nil
	}

	vector, err := d.embedder.Embed(ctx, keywordsText)
	if err != nil {
		return "", 0, err
	}

	var recordID int64
	err = d.pool.QueryRowEx(
		ctx,
		fmt.Sprintf(
			"SELECT record_id, CAST(1 - (embedding <=> $1::vector) AS real) AS score FROM %s ORDER BY embedding <=> $1::vector LIMIT 1",
			d.tableName(),
		),
		nil,
		vectorLiteral(vector),
	).Scan(&recordID, &score)
	if err != nil {
		// An empty table is reported as "no match", not as a failure.
		if err == pgx.ErrNoRows {
			return "", 0, nil
		}
		return "", 0, err
	}

	return strconv.FormatInt(recordID, 10), score, nil
}

// vectorLiteral renders values as a bracketed, comma-separated list such as
// "[0.1,0.25,3]", which the queries above cast with "::vector".
//
// Each component is written in plain decimal notation with the fewest digits
// that still round-trip as a float32. An empty or nil slice yields "[]".
func vectorLiteral(values []float32) string {
	parts := make([]string, 0, len(values))
	for _, value := range values {
		parts = append(parts, strconv.FormatFloat(float64(value), 'f', -1, 32))
	}
	return "[" + strings.Join(parts, ",") + "]"
}