redis缓存替换+pgvector向量替换
This commit is contained in:
@@ -3,7 +3,7 @@ package vector_data
|
||||
import (
|
||||
"ai-chat-service/pkg/config"
|
||||
"context"
|
||||
"github.com/tencent/vectordatabase-sdk-go/tcvectordb"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const CHAT_RECORDS = "chat_records"
|
||||
@@ -17,53 +17,13 @@ type IChatRecordsData interface {
|
||||
QueryData(ctx context.Context, text map[string][]string) (id string, score float32, err error)
|
||||
}
|
||||
|
||||
type chatRecordsData struct {
|
||||
config *config.Config
|
||||
vectorDB *tcvectordb.Client
|
||||
}
|
||||
|
||||
func NewChatRecordsData(config *config.Config, vectorDB *tcvectordb.Client) IChatRecordsData {
|
||||
return &chatRecordsData{
|
||||
config: config,
|
||||
vectorDB: vectorDB,
|
||||
func NewChatRecordsData(config *config.Config) (IChatRecordsData, error) {
|
||||
switch config.Vector.Provider {
|
||||
case "tencent", "":
|
||||
return newTencentChatRecordsData(config)
|
||||
case "pgvector":
|
||||
return newPgvectorChatRecordsData(config)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported vector provider: %s", config.Vector.Provider)
|
||||
}
|
||||
}
|
||||
func (data *chatRecordsData) UpsertData(ctx context.Context, list []*ChatRecord) error {
|
||||
database := data.config.VectorDB.Database
|
||||
collection := CHAT_RECORDS
|
||||
coll := data.vectorDB.Database(database).Collection(collection)
|
||||
documentList := make([]tcvectordb.Document, 0, len(list))
|
||||
for _, l := range list {
|
||||
doc := tcvectordb.Document{
|
||||
Id: l.ID,
|
||||
}
|
||||
doc.Fields = make(map[string]tcvectordb.Field, len(l.KVs))
|
||||
for k, v := range l.KVs {
|
||||
doc.Fields[k] = tcvectordb.Field{Val: v}
|
||||
}
|
||||
documentList = append(documentList, doc)
|
||||
}
|
||||
_, err := coll.Upsert(ctx, documentList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (data *chatRecordsData) QueryData(ctx context.Context, text map[string][]string) (id string, score float32, err error) {
|
||||
database := data.config.VectorDB.Database
|
||||
collection := CHAT_RECORDS
|
||||
coll := data.vectorDB.Database(database).Collection(collection)
|
||||
result, err := coll.SearchByText(ctx, text, &tcvectordb.SearchDocumentParams{
|
||||
Params: &tcvectordb.SearchDocParams{Ef: 100},
|
||||
Limit: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if len(result.Documents) > 0 && len(result.Documents[0]) > 0 {
|
||||
doc := result.Documents[0][0]
|
||||
return doc.Id, doc.Score, nil
|
||||
|
||||
}
|
||||
return "", 0, nil
|
||||
}
|
||||
|
||||
121
ai-chat-service/chat-server/vector-data/pgvector.go
Normal file
121
ai-chat-service/chat-server/vector-data/pgvector.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package vector_data
|
||||
|
||||
import (
|
||||
"ai-chat-service/pkg/config"
|
||||
"ai-chat-service/services/embedding"
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
|
||||
type pgvectorChatRecordsData struct {
|
||||
config *config.Config
|
||||
pool *pgx.ConnPool
|
||||
embedder embedding.Embedder
|
||||
}
|
||||
|
||||
func newPgvectorChatRecordsData(config *config.Config) (IChatRecordsData, error) {
|
||||
connConfig, err := pgx.ParseConnectionString(config.Vector.Pgvector.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pool, err := pgx.NewConnPool(pgx.ConnPoolConfig{
|
||||
ConnConfig: connConfig,
|
||||
MaxConnections: config.Vector.Pgvector.MaxOpenConn,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
embedder, err := embedding.NewEmbedder(config)
|
||||
if err != nil {
|
||||
pool.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &pgvectorChatRecordsData{
|
||||
config: config,
|
||||
pool: pool,
|
||||
embedder: embedder,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (data *pgvectorChatRecordsData) UpsertData(ctx context.Context, list []*ChatRecord) error {
|
||||
table := data.config.Vector.Pgvector.Table
|
||||
if table == "" {
|
||||
table = "chat_record_vectors"
|
||||
}
|
||||
for _, item := range list {
|
||||
recordID, err := strconv.ParseInt(item.ID, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keywordsText := embedding.BuildText(item.KVs["keywords"])
|
||||
if keywordsText == "" {
|
||||
continue
|
||||
}
|
||||
vector, err := data.embedder.Embed(ctx, keywordsText)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
commandTag, err := data.pool.Exec(
|
||||
fmt.Sprintf(
|
||||
"INSERT INTO %s (record_id, keywords_text, embedding, created_at) VALUES ($1, $2, $3::vector, $4) ON CONFLICT (record_id) DO UPDATE SET keywords_text = EXCLUDED.keywords_text, embedding = EXCLUDED.embedding, created_at = EXCLUDED.created_at",
|
||||
table,
|
||||
),
|
||||
recordID,
|
||||
keywordsText,
|
||||
vectorLiteral(vector),
|
||||
time.Now().Unix(),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if commandTag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("pgvector upsert affected 0 rows for record_id=%d", recordID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (data *pgvectorChatRecordsData) QueryData(ctx context.Context, text map[string][]string) (id string, score float32, err error) {
|
||||
keywordsText := embedding.BuildText(text["keywords"]...)
|
||||
if keywordsText == "" {
|
||||
return "", 0, nil
|
||||
}
|
||||
vector, err := data.embedder.Embed(ctx, keywordsText)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
table := data.config.Vector.Pgvector.Table
|
||||
if table == "" {
|
||||
table = "chat_record_vectors"
|
||||
}
|
||||
var recordID int64
|
||||
err = data.pool.QueryRowEx(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
"SELECT record_id, CAST(1 - (embedding <=> $1::vector) AS real) AS score FROM %s ORDER BY embedding <=> $1::vector LIMIT 1",
|
||||
table,
|
||||
),
|
||||
nil,
|
||||
vectorLiteral(vector),
|
||||
).Scan(&recordID, &score)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return "", 0, nil
|
||||
}
|
||||
return "", 0, err
|
||||
}
|
||||
return strconv.FormatInt(recordID, 10), score, nil
|
||||
}
|
||||
|
||||
func vectorLiteral(values []float32) string {
|
||||
parts := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
parts = append(parts, strconv.FormatFloat(float64(value), 'f', -1, 32))
|
||||
}
|
||||
return "[" + strings.Join(parts, ",") + "]"
|
||||
}
|
||||
66
ai-chat-service/chat-server/vector-data/tencent.go
Normal file
66
ai-chat-service/chat-server/vector-data/tencent.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package vector_data
|
||||
|
||||
import (
|
||||
"ai-chat-service/pkg/config"
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/tencent/vectordatabase-sdk-go/tcvectordb"
|
||||
)
|
||||
|
||||
type tencentChatRecordsData struct {
|
||||
config *config.Config
|
||||
vectorDB *tcvectordb.Client
|
||||
}
|
||||
|
||||
func newTencentChatRecordsData(config *config.Config) (IChatRecordsData, error) {
|
||||
option := &tcvectordb.ClientOption{
|
||||
Timeout: time.Second * time.Duration(config.Vector.Tencent.Timeout),
|
||||
MaxIdldConnPerHost: config.Vector.Tencent.MaxIdleConnPerHost,
|
||||
IdleConnTimeout: time.Second * time.Duration(config.Vector.Tencent.IdleConnTimeout),
|
||||
ReadConsistency: tcvectordb.ReadConsistency(config.Vector.Tencent.ReadConsistency),
|
||||
}
|
||||
client, err := tcvectordb.NewClient(config.Vector.Tencent.Url, config.Vector.Tencent.Username, config.Vector.Tencent.Pwd, option)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tencentChatRecordsData{
|
||||
config: config,
|
||||
vectorDB: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (data *tencentChatRecordsData) UpsertData(ctx context.Context, list []*ChatRecord) error {
|
||||
database := data.config.Vector.Tencent.Database
|
||||
collection := CHAT_RECORDS
|
||||
coll := data.vectorDB.Database(database).Collection(collection)
|
||||
documentList := make([]tcvectordb.Document, 0, len(list))
|
||||
for _, l := range list {
|
||||
doc := tcvectordb.Document{Id: l.ID}
|
||||
doc.Fields = make(map[string]tcvectordb.Field, len(l.KVs))
|
||||
for k, v := range l.KVs {
|
||||
doc.Fields[k] = tcvectordb.Field{Val: v}
|
||||
}
|
||||
documentList = append(documentList, doc)
|
||||
}
|
||||
_, err := coll.Upsert(ctx, documentList)
|
||||
return err
|
||||
}
|
||||
|
||||
func (data *tencentChatRecordsData) QueryData(ctx context.Context, text map[string][]string) (id string, score float32, err error) {
|
||||
database := data.config.Vector.Tencent.Database
|
||||
collection := CHAT_RECORDS
|
||||
coll := data.vectorDB.Database(database).Collection(collection)
|
||||
result, err := coll.SearchByText(ctx, text, &tcvectordb.SearchDocumentParams{
|
||||
Params: &tcvectordb.SearchDocParams{Ef: 100},
|
||||
Limit: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if len(result.Documents) > 0 && len(result.Documents[0]) > 0 {
|
||||
doc := result.Documents[0][0]
|
||||
return doc.Id, doc.Score, nil
|
||||
}
|
||||
return "", 0, nil
|
||||
}
|
||||
Reference in New Issue
Block a user