service 修改 Redis 存储 KV
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,3 +2,5 @@ docs
|
|||||||
.workspace.codex
|
.workspace.codex
|
||||||
|
|
||||||
.env
|
.env
|
||||||
|
.vscode
|
||||||
|
faiss
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
package chat_context
|
|
||||||
|
|
||||||
import "github.com/sashabaranov/go-openai"
|
|
||||||
|
|
||||||
type ChatMessage struct {
|
|
||||||
//当前记录ID
|
|
||||||
ID string `json:"id,omitempty"`
|
|
||||||
//上一条记录ID
|
|
||||||
PID string `json:"pid,omitempty"`
|
|
||||||
//消息内容
|
|
||||||
Message openai.ChatCompletionMessage `json:"message"`
|
|
||||||
//该消息tokens数
|
|
||||||
Tokens int `json:"tokens,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContextCache interface {
|
|
||||||
Get(key string) (*ChatMessage, error)
|
|
||||||
Set(key string, value *ChatMessage, ttl int) error
|
|
||||||
Close()
|
|
||||||
}
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
package chat_context
|
|
||||||
|
|
||||||
import (
|
|
||||||
predis "ai-chat-service/pkg/db/redis"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
|
||||||
|
|
||||||
type redisCache struct {
|
|
||||||
redisClient *redis.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRedisCache() ContextCache {
|
|
||||||
pool := predis.GetPool()
|
|
||||||
return &redisCache{
|
|
||||||
redisClient: pool.Get(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func getRedisKey(key string) string {
|
|
||||||
return predis.GetKey(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *redisCache) Get(key string) (*ChatMessage, error) {
|
|
||||||
key = getRedisKey(key)
|
|
||||||
str, err := c.redisClient.Get(context.Background(), key).Result()
|
|
||||||
if err == redis.Nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
value := &ChatMessage{}
|
|
||||||
err = json.Unmarshal([]byte(str), value)
|
|
||||||
return value, err
|
|
||||||
}
|
|
||||||
func (c *redisCache) Set(key string, value *ChatMessage, ttl int) error {
|
|
||||||
key = getRedisKey(key)
|
|
||||||
bytes, err := json.Marshal(value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
str := string(bytes)
|
|
||||||
return c.redisClient.SetEx(context.Background(), key, str, time.Duration(ttl)*time.Second).Err()
|
|
||||||
}
|
|
||||||
func (c *redisCache) Close() {
|
|
||||||
pool := predis.GetPool()
|
|
||||||
pool.Put(c.redisClient)
|
|
||||||
}
|
|
||||||
@@ -1,54 +1,59 @@
|
|||||||
package data
|
package data
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
predis "ai-chat-service/pkg/db/redis"
|
||||||
"strings"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
redis "github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IChatRecordsData interface {
|
type IChatRecordsData interface {
|
||||||
Add(record *ChatRecord) error
|
Add(record *ChatRecord) error
|
||||||
GetById(id int64) (record *ChatRecord, err error)
|
GetById(id string) (record *ChatRecord, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatRecord struct {
|
type ChatRecord struct {
|
||||||
ID int64 `json:"id"`
|
ID string `json:"-"`
|
||||||
UserMsg string `json:"user_msg"`
|
Question string `json:"q"`
|
||||||
UserMsgTokens int `json:"user_msg_tokens"`
|
Answer string `json:"a"`
|
||||||
UserMsgKeywords []string `json:"user_msg_keywords"`
|
|
||||||
AIMsg string `json:"ai_msg"`
|
|
||||||
AIMsgTokens int `json:"ai_msg_tokens"`
|
|
||||||
ReqTokens int `json:"req_tokens"`
|
|
||||||
CreateAt int64 `json:"create_at"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type chatRecordsData struct {
|
type chatRecordsData struct{}
|
||||||
db *sql.DB
|
|
||||||
|
func NewChatRecordsData() IChatRecordsData {
|
||||||
|
return &chatRecordsData{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatRecordsData(db *sql.DB) IChatRecordsData {
|
func (data *chatRecordsData) Add(record *ChatRecord) error {
|
||||||
return &chatRecordsData{
|
client := predis.GetPool().Get()
|
||||||
db: db,
|
defer predis.GetPool().Put(client)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (data *chatRecordsData) Add(cr *ChatRecord) (err error) {
|
payload, err := json.Marshal(&ChatRecord{
|
||||||
sqlStr := "insert into chat_records(user_msg,user_msg_tokens,user_msg_keywords,ai_msg,ai_msg_tokens,req_tokens,create_at)values(?,?,?,?,?,?,?)"
|
Question: record.Question,
|
||||||
res, err := data.db.Exec(sqlStr, cr.UserMsg, cr.UserMsgTokens, strings.Join(cr.UserMsgKeywords, ","), cr.AIMsg, cr.AIMsgTokens, cr.ReqTokens, cr.CreateAt)
|
Answer: record.Answer,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
cr.ID, _ = res.LastInsertId()
|
return client.Set(context.Background(), predis.GetKey("qa", record.ID), string(payload), 0).Err()
|
||||||
return
|
|
||||||
}
|
}
|
||||||
func (data *chatRecordsData) GetById(id int64) (cr *ChatRecord, err error) {
|
|
||||||
sqlStr := "select id,user_msg,user_msg_tokens,user_msg_keywords,ai_msg,ai_msg_tokens,req_tokens,create_at from chat_records where id = ?"
|
func (data *chatRecordsData) GetById(id string) (*ChatRecord, error) {
|
||||||
row := data.db.QueryRow(sqlStr, id)
|
client := predis.GetPool().Get()
|
||||||
cr = &ChatRecord{}
|
defer predis.GetPool().Put(client)
|
||||||
var keywords string
|
|
||||||
err = row.Scan(&cr.ID, &cr.UserMsg, &cr.UserMsgTokens, &keywords, &cr.AIMsg, &cr.AIMsgTokens, &cr.ReqTokens, &cr.CreateAt)
|
value, err := client.Get(context.Background(), predis.GetKey("qa", id)).Result()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
cr.UserMsgKeywords = strings.Split(keywords, ",")
|
|
||||||
return cr, err
|
record := &ChatRecord{ID: id}
|
||||||
|
if err = json.Unmarshal([]byte(value), record); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return record, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ import (
|
|||||||
metrics_app "ai-chat-service/chat-server/metrics-app"
|
metrics_app "ai-chat-service/chat-server/metrics-app"
|
||||||
metrics_bus "ai-chat-service/chat-server/metrics-bus"
|
metrics_bus "ai-chat-service/chat-server/metrics-bus"
|
||||||
"ai-chat-service/chat-server/server"
|
"ai-chat-service/chat-server/server"
|
||||||
vector_data "ai-chat-service/chat-server/vector-data"
|
|
||||||
"ai-chat-service/interceptor"
|
"ai-chat-service/interceptor"
|
||||||
"ai-chat-service/pkg/config"
|
"ai-chat-service/pkg/config"
|
||||||
"ai-chat-service/pkg/db/mysql"
|
|
||||||
"ai-chat-service/pkg/db/redis"
|
"ai-chat-service/pkg/db/redis"
|
||||||
"ai-chat-service/pkg/log"
|
"ai-chat-service/pkg/log"
|
||||||
"ai-chat-service/proto"
|
"ai-chat-service/proto"
|
||||||
|
"ai-chat-service/services/embedding"
|
||||||
|
"ai-chat-service/services/faiss"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -52,23 +52,22 @@ func main() {
|
|||||||
logger.SetOutput(log.GetRotateWriter(cnf.Log.LogPath))
|
logger.SetOutput(log.GetRotateWriter(cnf.Log.LogPath))
|
||||||
logger.SetPrintCaller(true)
|
logger.SetPrintCaller(true)
|
||||||
|
|
||||||
// 初始化Mysql
|
|
||||||
mysql.InitMysql(cnf)
|
|
||||||
// 初始化redis
|
// 初始化redis
|
||||||
redis.InitRedisPool(cnf)
|
redis.InitRedisPool(cnf)
|
||||||
|
|
||||||
recordsData := data.NewChatRecordsData(mysql.GetDB())
|
recordsData := data.NewChatRecordsData()
|
||||||
vectorRecordsData, err := vector_data.NewChatRecordsData(cnf)
|
embedder, err := embedding.NewEmbedder(cnf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
faissClient := faiss.NewClient(cnf)
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", cnf.Server.IP, cnf.Server.Port))
|
lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", cnf.Server.IP, cnf.Server.Port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
s := grpc.NewServer(grpc.UnaryInterceptor(interceptor.UnaryAuthInterceptor), grpc.StreamInterceptor(metrics_app.NewStreamMiddleware(registry).WrapHandler()))
|
s := grpc.NewServer(grpc.UnaryInterceptor(interceptor.UnaryAuthInterceptor), grpc.StreamInterceptor(metrics_app.NewStreamMiddleware(registry).WrapHandler()))
|
||||||
service := server.NewChatService(recordsData, vectorRecordsData, cnf, logger, busMetrics)
|
service := server.NewChatService(recordsData, embedder, faissClient, cnf, logger, busMetrics)
|
||||||
proto.RegisterChatServer(s, service)
|
proto.RegisterChatServer(s, service)
|
||||||
|
|
||||||
healthCheckSrv := health.NewServer()
|
healthCheckSrv := health.NewServer()
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
chat_context "ai-chat-service/chat-server/chat-context"
|
|
||||||
"ai-chat-service/pkg/config"
|
"ai-chat-service/pkg/config"
|
||||||
"ai-chat-service/pkg/log"
|
"ai-chat-service/pkg/log"
|
||||||
"ai-chat-service/pkg/zerror"
|
"ai-chat-service/pkg/zerror"
|
||||||
@@ -29,18 +28,15 @@ type openaiConf struct {
|
|||||||
PresencePenalty float32
|
PresencePenalty float32
|
||||||
FrequencyPenalty float32
|
FrequencyPenalty float32
|
||||||
BotDesc string
|
BotDesc string
|
||||||
ContextTTL int
|
|
||||||
ContextLen int
|
|
||||||
MinResponseTokens int
|
MinResponseTokens int
|
||||||
}
|
}
|
||||||
|
|
||||||
type app struct {
|
type app struct {
|
||||||
openaiConf *openaiConf
|
openaiConf *openaiConf
|
||||||
log log.ILogger
|
log log.ILogger
|
||||||
// TODO 内容上下文对象
|
|
||||||
contextCache chat_context.ContextCache
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *chatService) newApp(in *proto.ChatCompletionRequest, contextCache chat_context.ContextCache) *app {
|
func (s *chatService) newApp(in *proto.ChatCompletionRequest) *app {
|
||||||
conf := &openaiConf{
|
conf := &openaiConf{
|
||||||
ApiKey: s.config.Chat.ApiKey,
|
ApiKey: s.config.Chat.ApiKey,
|
||||||
BaseUrl: s.config.Chat.BaseUrl,
|
BaseUrl: s.config.Chat.BaseUrl,
|
||||||
@@ -51,8 +47,6 @@ func (s *chatService) newApp(in *proto.ChatCompletionRequest, contextCache chat_
|
|||||||
PresencePenalty: s.config.Chat.PresencePenalty,
|
PresencePenalty: s.config.Chat.PresencePenalty,
|
||||||
FrequencyPenalty: s.config.Chat.FrequencyPenalty,
|
FrequencyPenalty: s.config.Chat.FrequencyPenalty,
|
||||||
BotDesc: s.config.Chat.BotDesc,
|
BotDesc: s.config.Chat.BotDesc,
|
||||||
ContextTTL: s.config.Chat.ContextTTL,
|
|
||||||
ContextLen: s.config.Chat.ContextLen,
|
|
||||||
MinResponseTokens: s.config.Chat.MinResponseTokens,
|
MinResponseTokens: s.config.Chat.MinResponseTokens,
|
||||||
}
|
}
|
||||||
if in.ChatParam != nil {
|
if in.ChatParam != nil {
|
||||||
@@ -69,40 +63,29 @@ func (s *chatService) newApp(in *proto.ChatCompletionRequest, contextCache chat_
|
|||||||
if in.ChatParam.MaxTokens != 0 {
|
if in.ChatParam.MaxTokens != 0 {
|
||||||
conf.MaxTokens = int(in.ChatParam.MaxTokens)
|
conf.MaxTokens = int(in.ChatParam.MaxTokens)
|
||||||
}
|
}
|
||||||
if in.ChatParam.ContextTTL != 0 {
|
|
||||||
conf.ContextTTL = int(in.ChatParam.ContextTTL)
|
|
||||||
}
|
|
||||||
if in.ChatParam.ContextLen != 0 {
|
|
||||||
conf.ContextLen = int(in.ChatParam.ContextLen)
|
|
||||||
}
|
|
||||||
if in.ChatParam.MinResponseTokens != 0 {
|
if in.ChatParam.MinResponseTokens != 0 {
|
||||||
conf.MinResponseTokens = int(in.ChatParam.MinResponseTokens)
|
conf.MinResponseTokens = int(in.ChatParam.MinResponseTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &app{
|
return &app{
|
||||||
openaiConf: conf,
|
openaiConf: conf,
|
||||||
log: s.log,
|
log: s.log,
|
||||||
contextCache: contextCache,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *app) getOpenaiClient() *openai.Client {
|
func (a *app) getOpenaiClient() *openai.Client {
|
||||||
accessToken := a.openaiConf.ApiKey
|
conf := openai.DefaultConfig(a.openaiConf.ApiKey)
|
||||||
config := openai.DefaultConfig(accessToken)
|
conf.BaseURL = a.openaiConf.BaseUrl
|
||||||
config.BaseURL = a.openaiConf.BaseUrl
|
return openai.NewClientWithConfig(conf)
|
||||||
client := openai.NewClientWithConfig(config)
|
|
||||||
return client
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *app) buildChatCompletionRequest(in *proto.ChatCompletionRequest, stream bool) (req openai.ChatCompletionRequest, tokens, currTokens int, currMessage openai.ChatCompletionMessage, err error) {
|
func (a *app) buildChatCompletionRequest(in *proto.ChatCompletionRequest, stream bool) (req openai.ChatCompletionRequest, tokens, currTokens int, currMessage openai.ChatCompletionMessage, err error) {
|
||||||
//当前消息
|
|
||||||
currMessage = openai.ChatCompletionMessage{
|
currMessage = openai.ChatCompletionMessage{
|
||||||
Role: openai.ChatMessageRoleUser,
|
Role: openai.ChatMessageRoleUser,
|
||||||
Content: in.Message,
|
Content: in.Message,
|
||||||
}
|
}
|
||||||
req = openai.ChatCompletionRequest{
|
req = openai.ChatCompletionRequest{
|
||||||
Model: a.openaiConf.Model,
|
Model: a.openaiConf.Model,
|
||||||
Messages: []openai.ChatCompletionMessage{
|
|
||||||
currMessage,
|
|
||||||
},
|
|
||||||
MaxTokens: a.openaiConf.MinResponseTokens,
|
MaxTokens: a.openaiConf.MinResponseTokens,
|
||||||
Temperature: a.openaiConf.Temperature,
|
Temperature: a.openaiConf.Temperature,
|
||||||
TopP: a.openaiConf.TopP,
|
TopP: a.openaiConf.TopP,
|
||||||
@@ -110,13 +93,7 @@ func (a *app) buildChatCompletionRequest(in *proto.ChatCompletionRequest, stream
|
|||||||
FrequencyPenalty: a.openaiConf.FrequencyPenalty,
|
FrequencyPenalty: a.openaiConf.FrequencyPenalty,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
}
|
}
|
||||||
contextList := make([]*chat_context.ChatMessage, 0)
|
tokens, currTokens, req.Messages, err = a.rebuildMessages(currMessage)
|
||||||
if in.EnableContext {
|
|
||||||
//从缓存中获取上下文信息
|
|
||||||
contextList = a.getContext(in.Pid)
|
|
||||||
}
|
|
||||||
//重构req.Messages
|
|
||||||
tokens, currTokens, req.Messages, err = a.rebuildMessages(contextList, currMessage)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Error(err)
|
a.log.Error(err)
|
||||||
return
|
return
|
||||||
@@ -124,51 +101,37 @@ func (a *app) buildChatCompletionRequest(in *proto.ChatCompletionRequest, stream
|
|||||||
req.MaxTokens = a.openaiConf.MaxTokens - tokens
|
req.MaxTokens = a.openaiConf.MaxTokens - tokens
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
func (a *app) rebuildMessages(contextList []*chat_context.ChatMessage, currMessage openai.ChatCompletionMessage) (tokens, currTokens int, messages []openai.ChatCompletionMessage, err error) {
|
|
||||||
var sysMessage openai.ChatCompletionMessage
|
func (a *app) rebuildMessages(currMessage openai.ChatCompletionMessage) (tokens, currTokens int, messages []openai.ChatCompletionMessage, err error) {
|
||||||
|
messages = make([]openai.ChatCompletionMessage, 0, 2)
|
||||||
botTokens := 0
|
botTokens := 0
|
||||||
if a.openaiConf.BotDesc != "" {
|
if a.openaiConf.BotDesc != "" {
|
||||||
sysMessage = openai.ChatCompletionMessage{
|
sysMessage := openai.ChatCompletionMessage{
|
||||||
Role: openai.ChatMessageRoleSystem,
|
Role: openai.ChatMessageRoleSystem,
|
||||||
Content: a.openaiConf.BotDesc,
|
Content: a.openaiConf.BotDesc,
|
||||||
}
|
}
|
||||||
botTokens, err = tokenizer.GetTokens(&sysMessage, a.openaiConf.Model)
|
botTokens, err = tokenizer.GetTokens(&sysMessage, a.openaiConf.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Error(err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
messages = append(messages, sysMessage)
|
||||||
}
|
}
|
||||||
messages = []openai.ChatCompletionMessage{currMessage}
|
|
||||||
currTokens, err = tokenizer.GetTokens(&currMessage, a.openaiConf.Model)
|
currTokens, err = tokenizer.GetTokens(&currMessage, a.openaiConf.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Error(err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if currTokens > a.openaiConf.MaxTokens-a.openaiConf.MinResponseTokens-botTokens-ChatPrimedTokens {
|
if currTokens > a.openaiConf.MaxTokens-a.openaiConf.MinResponseTokens-botTokens-ChatPrimedTokens {
|
||||||
err = zerror.NewByMsg("请求消息超限")
|
return 0, 0, nil, zerror.NewByMsg("请求消息超限")
|
||||||
a.log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens = currTokens + botTokens + ChatPrimedTokens
|
tokens = currTokens + botTokens + ChatPrimedTokens
|
||||||
if contextList != nil {
|
messages = append(messages, currMessage)
|
||||||
for _, item := range contextList {
|
|
||||||
if tokens+item.Tokens+ChatPrimedTokens > a.openaiConf.MaxTokens-a.openaiConf.MinResponseTokens {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
messages = append(messages, item.Message)
|
|
||||||
tokens += item.Tokens + ChatPrimedTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
|
|
||||||
messages[i], messages[j] = messages[j], messages[i]
|
|
||||||
}
|
|
||||||
if botTokens > 0 {
|
|
||||||
messages = append([]openai.ChatCompletionMessage{sysMessage}, messages...)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *app) buildChatCompletionResponse(msg string) *proto.ChatCompletionResponse {
|
func (a *app) buildChatCompletionResponse(msg string) *proto.ChatCompletionResponse {
|
||||||
res := &proto.ChatCompletionResponse{
|
return &proto.ChatCompletionResponse{
|
||||||
Id: uuid.New().String(),
|
Id: uuid.New().String(),
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: time.Now().Unix(),
|
Created: time.Now().Unix(),
|
||||||
@@ -182,17 +145,12 @@ func (a *app) buildChatCompletionResponse(msg string) *proto.ChatCompletionRespo
|
|||||||
FinishReason: "stop",
|
FinishReason: "stop",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Usage: &proto.Usage{
|
Usage: &proto.Usage{},
|
||||||
PromptTokens: 0,
|
|
||||||
CompletionTokens: 0,
|
|
||||||
TotalTokens: 0,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
return res
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *app) buildChatCompletionStreamResponse(id, delta, finishReason string) *proto.ChatCompletionStreamResponse {
|
func (a *app) buildChatCompletionStreamResponse(id, delta, finishReason string) *proto.ChatCompletionStreamResponse {
|
||||||
res := &proto.ChatCompletionStreamResponse{
|
return &proto.ChatCompletionStreamResponse{
|
||||||
Id: id,
|
Id: id,
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: time.Now().Unix(),
|
Created: time.Now().Unix(),
|
||||||
@@ -208,79 +166,49 @@ func (a *app) buildChatCompletionStreamResponse(id, delta, finishReason string)
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return res
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *app) buildChatCompletionStreamResponseList(id, msg string) []*proto.ChatCompletionStreamResponse {
|
func (a *app) buildChatCompletionStreamResponseList(id, msg string) []*proto.ChatCompletionStreamResponse {
|
||||||
list := make([]*proto.ChatCompletionStreamResponse, 0)
|
list := make([]*proto.ChatCompletionStreamResponse, 0, len(msg))
|
||||||
for _, delta := range msg {
|
for _, delta := range msg {
|
||||||
list = append(list, a.buildChatCompletionStreamResponse(id, string(delta), ""))
|
list = append(list, a.buildChatCompletionStreamResponse(id, string(delta), ""))
|
||||||
}
|
}
|
||||||
return list
|
return list
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *app) getContext(id string) []*chat_context.ChatMessage {
|
|
||||||
maxLen := a.openaiConf.ContextLen
|
|
||||||
list := make([]*chat_context.ChatMessage, 0, maxLen)
|
|
||||||
key := id
|
|
||||||
for i := 0; i < maxLen; i++ {
|
|
||||||
value, err := a.contextCache.Get(key)
|
|
||||||
if err != nil {
|
|
||||||
a.log.Error(err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if value == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
list = append(list, value)
|
|
||||||
key = value.PID
|
|
||||||
}
|
|
||||||
return list
|
|
||||||
}
|
|
||||||
func (a *app) saveContext(value *chat_context.ChatMessage) error {
|
|
||||||
err := a.contextCache.Set(value.ID, value, a.openaiConf.ContextTTL)
|
|
||||||
if err != nil {
|
|
||||||
a.log.Error(err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (a *app) keywords(in *proto.ChatCompletionRequest) []string {
|
func (a *app) keywords(in *proto.ChatCompletionRequest) []string {
|
||||||
pool := keywords_filter.GetKeywordsClientPool()
|
pool := keywords_filter.GetKeywordsClientPool()
|
||||||
conn := pool.Get()
|
conn := pool.Get()
|
||||||
defer pool.Put(conn)
|
defer pool.Put(conn)
|
||||||
|
|
||||||
accessToken := config.GetConfig().DependOn.Keywords.AccessToken
|
accessToken := config.GetConfig().DependOn.Keywords.AccessToken
|
||||||
client := keywords_proto.NewFilterClient(conn)
|
client := keywords_proto.NewFilterClient(conn)
|
||||||
ctx := services.AppendBearerTokenToContext(context.Background(), accessToken)
|
ctx := services.AppendBearerTokenToContext(context.Background(), accessToken)
|
||||||
req := &keywords_proto.FilterReq{
|
req := &keywords_proto.FilterReq{Text: in.Message}
|
||||||
Text: in.Message,
|
|
||||||
}
|
|
||||||
res, err := client.FindAll(ctx, req)
|
res, err := client.FindAll(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Error(err)
|
a.log.Error(err)
|
||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
return res.Keywords
|
return res.Keywords
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *app) sensitive(in *proto.ChatCompletionRequest) (ok bool, msg string, err error) {
|
func (a *app) sensitive(in *proto.ChatCompletionRequest) (ok bool, msg string, err error) {
|
||||||
pool := keywords_filter.GetSensitiveClientPool()
|
pool := keywords_filter.GetSensitiveClientPool()
|
||||||
conn := pool.Get()
|
conn := pool.Get()
|
||||||
defer pool.Put(conn)
|
defer pool.Put(conn)
|
||||||
|
|
||||||
accessToken := config.GetConfig().DependOn.Sensitive.AccessToken
|
accessToken := config.GetConfig().DependOn.Sensitive.AccessToken
|
||||||
client := keywords_proto.NewFilterClient(conn)
|
client := keywords_proto.NewFilterClient(conn)
|
||||||
ctx := services.AppendBearerTokenToContext(context.Background(), accessToken)
|
ctx := services.AppendBearerTokenToContext(context.Background(), accessToken)
|
||||||
req := &keywords_proto.FilterReq{
|
req := &keywords_proto.FilterReq{Text: in.Message}
|
||||||
Text: in.Message,
|
|
||||||
}
|
|
||||||
res, err := client.Validate(ctx, req)
|
res, err := client.Validate(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Error(err)
|
a.log.Error(err)
|
||||||
return false, "", err
|
return false, "", err
|
||||||
}
|
}
|
||||||
ok = res.Ok
|
if !res.Ok {
|
||||||
if !ok {
|
return false, "触发到了知识盲区,请换个问题再问", nil
|
||||||
msg = "触发到了知识盲区,请换个问题再问"
|
|
||||||
}
|
}
|
||||||
return
|
return true, "", nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,20 +1,18 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
chat_context "ai-chat-service/chat-server/chat-context"
|
|
||||||
"ai-chat-service/chat-server/data"
|
"ai-chat-service/chat-server/data"
|
||||||
metrics_bus "ai-chat-service/chat-server/metrics-bus"
|
metrics_bus "ai-chat-service/chat-server/metrics-bus"
|
||||||
vector_data "ai-chat-service/chat-server/vector-data"
|
|
||||||
"ai-chat-service/pkg/config"
|
"ai-chat-service/pkg/config"
|
||||||
"ai-chat-service/pkg/log"
|
"ai-chat-service/pkg/log"
|
||||||
"ai-chat-service/proto"
|
"ai-chat-service/proto"
|
||||||
|
"ai-chat-service/services/embedding"
|
||||||
|
"ai-chat-service/services/faiss"
|
||||||
"ai-chat-service/services/tokenizer"
|
"ai-chat-service/services/tokenizer"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang/protobuf/jsonpb"
|
"github.com/golang/protobuf/jsonpb"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -26,218 +24,127 @@ type chatService struct {
|
|||||||
config *config.Config
|
config *config.Config
|
||||||
log log.ILogger
|
log log.ILogger
|
||||||
data data.IChatRecordsData
|
data data.IChatRecordsData
|
||||||
vectorData vector_data.IChatRecordsData
|
embedder embedding.Embedder
|
||||||
|
faiss faiss.Client
|
||||||
busMetrics *metrics_bus.BusMetrics
|
busMetrics *metrics_bus.BusMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatService(data data.IChatRecordsData, vectorData vector_data.IChatRecordsData, config *config.Config, log log.ILogger, busMetrics *metrics_bus.BusMetrics) proto.ChatServer {
|
func NewChatService(data data.IChatRecordsData, embedder embedding.Embedder, faissClient faiss.Client, config *config.Config, log log.ILogger, busMetrics *metrics_bus.BusMetrics) proto.ChatServer {
|
||||||
return &chatService{
|
return &chatService{
|
||||||
config: config,
|
config: config,
|
||||||
log: log,
|
log: log,
|
||||||
data: data,
|
data: data,
|
||||||
vectorData: vectorData,
|
embedder: embedder,
|
||||||
|
faiss: faissClient,
|
||||||
busMetrics: busMetrics,
|
busMetrics: busMetrics,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompletionRequest) (*proto.ChatCompletionResponse, error) {
|
func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompletionRequest) (*proto.ChatCompletionResponse, error) {
|
||||||
redisContextCache := chat_context.NewRedisCache()
|
app := s.newApp(in)
|
||||||
defer redisContextCache.Close()
|
|
||||||
|
|
||||||
app := s.newApp(in, redisContextCache)
|
|
||||||
//敏感词过滤
|
|
||||||
ok, msg, err := app.sensitive(in)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
res := app.buildChatCompletionResponse(msg)
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//关键词提取
|
|
||||||
keywords := app.keywords(in)
|
|
||||||
if len(keywords) > 0 {
|
|
||||||
idStr, score, err := s.vectorData.QueryData(context.Background(), map[string][]string{"keywords": {strings.Join(keywords, ",")}})
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
} else if score > s.config.Vector.Threshold {
|
|
||||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
} else {
|
|
||||||
record, err := s.data.GetById(id)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
} else {
|
|
||||||
res := app.buildChatCompletionResponse(record.AIMsg)
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
client := app.getOpenaiClient()
|
|
||||||
req, tokens, currTokens, currMessage, err := app.buildChatCompletionRequest(in, false)
|
|
||||||
resp, err := client.CreateChatCompletion(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
res := &proto.ChatCompletionResponse{}
|
|
||||||
bytes, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = jsonpb.UnmarshalString(string(bytes), res)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
reqContext := &chat_context.ChatMessage{
|
|
||||||
ID: in.Id,
|
|
||||||
PID: in.Pid,
|
|
||||||
Message: currMessage,
|
|
||||||
Tokens: currTokens,
|
|
||||||
}
|
|
||||||
err := app.saveContext(reqContext)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resContext := &chat_context.ChatMessage{
|
|
||||||
ID: resp.ID,
|
|
||||||
PID: reqContext.ID,
|
|
||||||
Message: resp.Choices[0].Message,
|
|
||||||
Tokens: resp.Usage.CompletionTokens,
|
|
||||||
}
|
|
||||||
err = app.saveContext(resContext)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
records := &data.ChatRecord{
|
|
||||||
UserMsg: in.Message,
|
|
||||||
UserMsgTokens: currTokens,
|
|
||||||
UserMsgKeywords: keywords,
|
|
||||||
AIMsg: resp.Choices[0].Message.Content,
|
|
||||||
AIMsgTokens: resp.Usage.CompletionTokens,
|
|
||||||
ReqTokens: tokens,
|
|
||||||
CreateAt: time.Now().Unix(),
|
|
||||||
}
|
|
||||||
err := s.data.Add(records)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//保存到向量数据库
|
|
||||||
if len(keywords) > 0 {
|
|
||||||
list := []*vector_data.ChatRecord{
|
|
||||||
{
|
|
||||||
ID: strconv.FormatInt(records.ID, 10),
|
|
||||||
KVs: map[string]string{
|
|
||||||
"keywords": strings.Join(keywords, ","),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
err = s.vectorData.UpsertData(context.Background(), list)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stream proto.Chat_ChatCompletionStreamServer) error {
|
|
||||||
redisContextCache := chat_context.NewRedisCache()
|
|
||||||
defer redisContextCache.Close()
|
|
||||||
|
|
||||||
app := s.newApp(in, redisContextCache)
|
|
||||||
//敏感词过滤
|
|
||||||
ok, msg, err := app.sensitive(in)
|
ok, msg, err := app.sensitive(in)
|
||||||
|
if err != nil {
|
||||||
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
|
||||||
|
return app.buildChatCompletionResponse(msg), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keywords := app.keywords(in)
|
||||||
|
if len(keywords) > 0 {
|
||||||
|
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _, _, _, err := app.buildChatCompletionRequest(in, false)
|
||||||
|
if err != nil {
|
||||||
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
questionEmbedding, cachedRecord := s.searchCachedAnswer(ctx, in.Message)
|
||||||
|
if cachedRecord != nil {
|
||||||
|
return app.buildChatCompletionResponse(cachedRecord.Answer), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client := app.getOpenaiClient()
|
||||||
|
resp, err := client.CreateChatCompletion(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
s.log.Error(err)
|
s.log.Error(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
res := &proto.ChatCompletionResponse{}
|
||||||
|
bytes, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = jsonpb.UnmarshalString(string(bytes), res); err != nil {
|
||||||
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.Choices) > 0 {
|
||||||
|
if err = s.persistQA(ctx, questionEmbedding, in.Message, resp.Choices[0].Message.Content); err != nil {
|
||||||
|
s.log.Error(err)
|
||||||
|
} else {
|
||||||
|
s.busMetrics.QuestionsTotalCounter.Inc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stream proto.Chat_ChatCompletionStreamServer) error {
|
||||||
|
app := s.newApp(in)
|
||||||
|
|
||||||
|
ok, msg, err := app.sensitive(in)
|
||||||
|
if err != nil {
|
||||||
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
|
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
|
||||||
resId := uuid.New().String()
|
resID := uuid.New().String()
|
||||||
startRes := app.buildChatCompletionStreamResponse(resId, "", "")
|
if err = stream.Send(app.buildChatCompletionStreamResponse(resID, "", "")); err != nil {
|
||||||
endRes := app.buildChatCompletionStreamResponse(resId, "", "stop")
|
|
||||||
err = stream.Send(startRes)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
resList := app.buildChatCompletionStreamResponseList(resId, msg)
|
for _, res := range app.buildChatCompletionStreamResponseList(resID, msg) {
|
||||||
for _, res := range resList {
|
if err = stream.Send(res); err != nil {
|
||||||
err = stream.Send(res)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = stream.Send(endRes)
|
return stream.Send(app.buildChatCompletionStreamResponse(resID, "", "stop"))
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//关键词提取
|
|
||||||
keywords := app.keywords(in)
|
keywords := app.keywords(in)
|
||||||
|
|
||||||
if len(keywords) > 0 {
|
if len(keywords) > 0 {
|
||||||
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
|
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
|
||||||
idStr, score, err := s.vectorData.QueryData(context.Background(), map[string][]string{"keywords": {strings.Join(keywords, ",")}})
|
}
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
req, _, _, _, err := app.buildChatCompletionRequest(in, true)
|
||||||
} else if score > s.config.Vector.Threshold {
|
if err != nil {
|
||||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
if err != nil {
|
return err
|
||||||
s.log.Error(err)
|
}
|
||||||
} else {
|
|
||||||
record, err := s.data.GetById(id)
|
questionEmbedding, cachedRecord := s.searchCachedAnswer(stream.Context(), in.Message)
|
||||||
if err != nil {
|
if cachedRecord != nil {
|
||||||
s.log.Error(err)
|
if err = stream.Send(app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "")); err != nil {
|
||||||
} else {
|
return err
|
||||||
resId := uuid.New().String()
|
}
|
||||||
startRes := app.buildChatCompletionStreamResponse(resId, "", "")
|
for _, res := range app.buildChatCompletionStreamResponseList(cachedRecord.ID, cachedRecord.Answer) {
|
||||||
endRes := app.buildChatCompletionStreamResponse(resId, "", "stop")
|
if err = stream.Send(res); err != nil {
|
||||||
err = stream.Send(startRes)
|
return err
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
resList := app.buildChatCompletionStreamResponseList(resId, record.AIMsg)
|
|
||||||
for _, res := range resList {
|
|
||||||
err = stream.Send(res)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = stream.Send(endRes)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return stream.Send(app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "stop"))
|
||||||
}
|
}
|
||||||
|
|
||||||
client := app.getOpenaiClient()
|
client := app.getOpenaiClient()
|
||||||
req, tokens, currTokens, currMessage, err := app.buildChatCompletionRequest(in, false)
|
|
||||||
chatStream, err := client.CreateChatCompletionStream(stream.Context(), req)
|
chatStream, err := client.CreateChatCompletionStream(stream.Context(), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
@@ -245,109 +152,106 @@ func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stre
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer chatStream.Close()
|
defer chatStream.Close()
|
||||||
|
|
||||||
completionContent := ""
|
completionContent := ""
|
||||||
resultID := ""
|
|
||||||
for {
|
for {
|
||||||
resp, err := chatStream.Recv()
|
resp, err := chatStream.Recv()
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
s.log.Error(err)
|
s.log.Error(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if resultID == "" {
|
|
||||||
resultID = resp.ID
|
|
||||||
}
|
|
||||||
completionContent += resp.Choices[0].Delta.Content
|
completionContent += resp.Choices[0].Delta.Content
|
||||||
|
|
||||||
res := &proto.ChatCompletionStreamResponse{}
|
res := &proto.ChatCompletionStreamResponse{}
|
||||||
bytes, err := json.Marshal(resp)
|
bytes, err := json.Marshal(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error(err)
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = jsonpb.UnmarshalString(string(bytes), res)
|
if err = jsonpb.UnmarshalString(string(bytes), res); err != nil {
|
||||||
if err != nil {
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
s.log.Error(err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = stream.Send(res)
|
if err = stream.Send(res); err != nil {
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model := s.config.Chat.Model
|
||||||
|
if in.ChatParam != nil && in.ChatParam.Model != "" {
|
||||||
|
model = in.ChatParam.Model
|
||||||
|
}
|
||||||
resultMessage := openai.ChatCompletionMessage{
|
resultMessage := openai.ChatCompletionMessage{
|
||||||
Role: openai.ChatMessageRoleAssistant,
|
Role: openai.ChatMessageRoleAssistant,
|
||||||
Content: completionContent,
|
Content: completionContent,
|
||||||
}
|
}
|
||||||
model := s.config.Chat.Model
|
if _, err = tokenizer.GetTokens(&resultMessage, model); err != nil {
|
||||||
if in.ChatParam != nil && in.ChatParam.Model != "" {
|
|
||||||
model = in.ChatParam.Model
|
|
||||||
}
|
|
||||||
resultTokens, err := tokenizer.GetTokens(&resultMessage, model)
|
|
||||||
if err != nil {
|
|
||||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||||
s.log.Error(err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
if err = s.persistQA(stream.Context(), questionEmbedding, in.Message, completionContent); err != nil {
|
||||||
reqContext := &chat_context.ChatMessage{
|
s.log.Error(err)
|
||||||
ID: in.Id,
|
} else {
|
||||||
PID: in.Pid,
|
|
||||||
Message: currMessage,
|
|
||||||
Tokens: currTokens,
|
|
||||||
}
|
|
||||||
err := app.saveContext(reqContext)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resContext := &chat_context.ChatMessage{
|
|
||||||
ID: resultID,
|
|
||||||
PID: reqContext.ID,
|
|
||||||
Message: resultMessage,
|
|
||||||
Tokens: resultTokens,
|
|
||||||
}
|
|
||||||
err = app.saveContext(resContext)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
s.busMetrics.QuestionsTotalCounter.Inc()
|
s.busMetrics.QuestionsTotalCounter.Inc()
|
||||||
records := &data.ChatRecord{
|
}
|
||||||
UserMsg: in.Message,
|
|
||||||
UserMsgTokens: currTokens,
|
|
||||||
UserMsgKeywords: keywords,
|
|
||||||
AIMsg: completionContent,
|
|
||||||
AIMsgTokens: resultTokens,
|
|
||||||
ReqTokens: tokens,
|
|
||||||
CreateAt: time.Now().Unix(),
|
|
||||||
}
|
|
||||||
err := s.data.Add(records)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//保存到向量数据库
|
|
||||||
if len(keywords) > 0 {
|
|
||||||
list := []*vector_data.ChatRecord{
|
|
||||||
{
|
|
||||||
ID: strconv.FormatInt(records.ID, 10),
|
|
||||||
KVs: map[string]string{
|
|
||||||
"keywords": strings.Join(keywords, ","),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
err = s.vectorData.UpsertData(context.Background(), list)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *chatService) searchCachedAnswer(ctx context.Context, question string) ([]float32, *data.ChatRecord) {
|
||||||
|
embeddingVector, err := s.embedder.Embed(ctx, question)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error(err)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
searchRes, err := s.faiss.Search(ctx, embeddingVector, s.config.Faiss.SearchK)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error(err)
|
||||||
|
return embeddingVector, nil
|
||||||
|
}
|
||||||
|
if searchRes == nil || len(searchRes.IDs) == 0 || len(searchRes.SimilarityScores) == 0 {
|
||||||
|
return embeddingVector, nil
|
||||||
|
}
|
||||||
|
limit := len(searchRes.IDs)
|
||||||
|
if len(searchRes.SimilarityScores) < limit {
|
||||||
|
limit = len(searchRes.SimilarityScores)
|
||||||
|
}
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
if searchRes.IDs[i] < 0 || searchRes.SimilarityScores[i] < s.config.Faiss.SimilarityThreshold {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
record, err := s.data.GetById(strconv.FormatInt(searchRes.IDs[i], 10))
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error(err)
|
||||||
|
return embeddingVector, nil
|
||||||
|
}
|
||||||
|
if record != nil {
|
||||||
|
return embeddingVector, record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return embeddingVector, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *chatService) persistQA(ctx context.Context, questionEmbedding []float32, question, answer string) error {
|
||||||
|
if len(questionEmbedding) == 0 {
|
||||||
|
vector, err := s.embedder.Embed(ctx, question)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
questionEmbedding = vector
|
||||||
|
}
|
||||||
|
id, err := s.faiss.Insert(ctx, questionEmbedding)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.data.Add(&data.ChatRecord{
|
||||||
|
ID: id,
|
||||||
|
Question: question,
|
||||||
|
Answer: answer,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,29 +0,0 @@
|
|||||||
package vector_data
|
|
||||||
|
|
||||||
import (
|
|
||||||
"ai-chat-service/pkg/config"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
const CHAT_RECORDS = "chat_records"
|
|
||||||
|
|
||||||
type ChatRecord struct {
|
|
||||||
ID string
|
|
||||||
KVs map[string]string
|
|
||||||
}
|
|
||||||
type IChatRecordsData interface {
|
|
||||||
UpsertData(ctx context.Context, list []*ChatRecord) error
|
|
||||||
QueryData(ctx context.Context, text map[string][]string) (id string, score float32, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewChatRecordsData(config *config.Config) (IChatRecordsData, error) {
|
|
||||||
switch config.Vector.Provider {
|
|
||||||
case "tencent", "":
|
|
||||||
return newTencentChatRecordsData(config)
|
|
||||||
case "pgvector":
|
|
||||||
return newPgvectorChatRecordsData(config)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported vector provider: %s", config.Vector.Provider)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,121 +0,0 @@
|
|||||||
package vector_data
|
|
||||||
|
|
||||||
import (
|
|
||||||
"ai-chat-service/pkg/config"
|
|
||||||
"ai-chat-service/services/embedding"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx"
|
|
||||||
)
|
|
||||||
|
|
||||||
type pgvectorChatRecordsData struct {
|
|
||||||
config *config.Config
|
|
||||||
pool *pgx.ConnPool
|
|
||||||
embedder embedding.Embedder
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPgvectorChatRecordsData(config *config.Config) (IChatRecordsData, error) {
|
|
||||||
connConfig, err := pgx.ParseConnectionString(config.Vector.Pgvector.DSN)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
pool, err := pgx.NewConnPool(pgx.ConnPoolConfig{
|
|
||||||
ConnConfig: connConfig,
|
|
||||||
MaxConnections: config.Vector.Pgvector.MaxOpenConn,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
embedder, err := embedding.NewEmbedder(config)
|
|
||||||
if err != nil {
|
|
||||||
pool.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &pgvectorChatRecordsData{
|
|
||||||
config: config,
|
|
||||||
pool: pool,
|
|
||||||
embedder: embedder,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (data *pgvectorChatRecordsData) UpsertData(ctx context.Context, list []*ChatRecord) error {
|
|
||||||
table := data.config.Vector.Pgvector.Table
|
|
||||||
if table == "" {
|
|
||||||
table = "chat_record_vectors"
|
|
||||||
}
|
|
||||||
for _, item := range list {
|
|
||||||
recordID, err := strconv.ParseInt(item.ID, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
keywordsText := embedding.BuildText(item.KVs["keywords"])
|
|
||||||
if keywordsText == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
vector, err := data.embedder.Embed(ctx, keywordsText)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
commandTag, err := data.pool.Exec(
|
|
||||||
fmt.Sprintf(
|
|
||||||
"INSERT INTO %s (record_id, keywords_text, embedding, created_at) VALUES ($1, $2, $3::vector, $4) ON CONFLICT (record_id) DO UPDATE SET keywords_text = EXCLUDED.keywords_text, embedding = EXCLUDED.embedding, created_at = EXCLUDED.created_at",
|
|
||||||
table,
|
|
||||||
),
|
|
||||||
recordID,
|
|
||||||
keywordsText,
|
|
||||||
vectorLiteral(vector),
|
|
||||||
time.Now().Unix(),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if commandTag.RowsAffected() == 0 {
|
|
||||||
return fmt.Errorf("pgvector upsert affected 0 rows for record_id=%d", recordID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (data *pgvectorChatRecordsData) QueryData(ctx context.Context, text map[string][]string) (id string, score float32, err error) {
|
|
||||||
keywordsText := embedding.BuildText(text["keywords"]...)
|
|
||||||
if keywordsText == "" {
|
|
||||||
return "", 0, nil
|
|
||||||
}
|
|
||||||
vector, err := data.embedder.Embed(ctx, keywordsText)
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, err
|
|
||||||
}
|
|
||||||
table := data.config.Vector.Pgvector.Table
|
|
||||||
if table == "" {
|
|
||||||
table = "chat_record_vectors"
|
|
||||||
}
|
|
||||||
var recordID int64
|
|
||||||
err = data.pool.QueryRowEx(
|
|
||||||
ctx,
|
|
||||||
fmt.Sprintf(
|
|
||||||
"SELECT record_id, CAST(1 - (embedding <=> $1::vector) AS real) AS score FROM %s ORDER BY embedding <=> $1::vector LIMIT 1",
|
|
||||||
table,
|
|
||||||
),
|
|
||||||
nil,
|
|
||||||
vectorLiteral(vector),
|
|
||||||
).Scan(&recordID, &score)
|
|
||||||
if err != nil {
|
|
||||||
if err == pgx.ErrNoRows {
|
|
||||||
return "", 0, nil
|
|
||||||
}
|
|
||||||
return "", 0, err
|
|
||||||
}
|
|
||||||
return strconv.FormatInt(recordID, 10), score, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func vectorLiteral(values []float32) string {
|
|
||||||
parts := make([]string, 0, len(values))
|
|
||||||
for _, value := range values {
|
|
||||||
parts = append(parts, strconv.FormatFloat(float64(value), 'f', -1, 32))
|
|
||||||
}
|
|
||||||
return "[" + strings.Join(parts, ",") + "]"
|
|
||||||
}
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
package vector_data
|
|
||||||
|
|
||||||
import (
|
|
||||||
"ai-chat-service/pkg/config"
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/tencent/vectordatabase-sdk-go/tcvectordb"
|
|
||||||
)
|
|
||||||
|
|
||||||
type tencentChatRecordsData struct {
|
|
||||||
config *config.Config
|
|
||||||
vectorDB *tcvectordb.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTencentChatRecordsData(config *config.Config) (IChatRecordsData, error) {
|
|
||||||
option := &tcvectordb.ClientOption{
|
|
||||||
Timeout: time.Second * time.Duration(config.Vector.Tencent.Timeout),
|
|
||||||
MaxIdldConnPerHost: config.Vector.Tencent.MaxIdleConnPerHost,
|
|
||||||
IdleConnTimeout: time.Second * time.Duration(config.Vector.Tencent.IdleConnTimeout),
|
|
||||||
ReadConsistency: tcvectordb.ReadConsistency(config.Vector.Tencent.ReadConsistency),
|
|
||||||
}
|
|
||||||
client, err := tcvectordb.NewClient(config.Vector.Tencent.Url, config.Vector.Tencent.Username, config.Vector.Tencent.Pwd, option)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &tencentChatRecordsData{
|
|
||||||
config: config,
|
|
||||||
vectorDB: client,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (data *tencentChatRecordsData) UpsertData(ctx context.Context, list []*ChatRecord) error {
|
|
||||||
database := data.config.Vector.Tencent.Database
|
|
||||||
collection := CHAT_RECORDS
|
|
||||||
coll := data.vectorDB.Database(database).Collection(collection)
|
|
||||||
documentList := make([]tcvectordb.Document, 0, len(list))
|
|
||||||
for _, l := range list {
|
|
||||||
doc := tcvectordb.Document{Id: l.ID}
|
|
||||||
doc.Fields = make(map[string]tcvectordb.Field, len(l.KVs))
|
|
||||||
for k, v := range l.KVs {
|
|
||||||
doc.Fields[k] = tcvectordb.Field{Val: v}
|
|
||||||
}
|
|
||||||
documentList = append(documentList, doc)
|
|
||||||
}
|
|
||||||
_, err := coll.Upsert(ctx, documentList)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (data *tencentChatRecordsData) QueryData(ctx context.Context, text map[string][]string) (id string, score float32, err error) {
|
|
||||||
database := data.config.Vector.Tencent.Database
|
|
||||||
collection := CHAT_RECORDS
|
|
||||||
coll := data.vectorDB.Database(database).Collection(collection)
|
|
||||||
result, err := coll.SearchByText(ctx, text, &tcvectordb.SearchDocumentParams{
|
|
||||||
Params: &tcvectordb.SearchDocParams{Ef: 100},
|
|
||||||
Limit: 1,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, err
|
|
||||||
}
|
|
||||||
if len(result.Documents) > 0 && len(result.Documents[0]) > 0 {
|
|
||||||
doc := result.Documents[0][0]
|
|
||||||
return doc.Id, doc.Score, nil
|
|
||||||
}
|
|
||||||
return "", 0, nil
|
|
||||||
}
|
|
||||||
@@ -32,19 +32,10 @@ chat:
|
|||||||
bot_desc: "你是一个AI助手,我需要你模拟一名资深的软件工程师来回答我的问题"
|
bot_desc: "你是一个AI助手,我需要你模拟一名资深的软件工程师来回答我的问题"
|
||||||
# 单次请求,保留的响应tokens数量
|
# 单次请求,保留的响应tokens数量
|
||||||
min_response_tokens: 2048
|
min_response_tokens: 2048
|
||||||
# 上下文缓存时长,单位s
|
|
||||||
context_ttl: 1800
|
|
||||||
# 上下文消息条数
|
|
||||||
context_len: 4
|
|
||||||
redis:
|
redis:
|
||||||
host: "127.0.0.1"
|
host: "127.0.0.1"
|
||||||
port: 8888
|
port: 8888
|
||||||
pwd: "123456"
|
pwd: "123456"
|
||||||
mysql:
|
|
||||||
dsn: "root:root@tcp(127.0.0.1:3306)/ai_chat?collation=utf8mb4_unicode_ci&charset=utf8mb4"
|
|
||||||
maxLifeTime: 3600
|
|
||||||
maxOpenConn: 10
|
|
||||||
maxIdleConn: 10
|
|
||||||
dependOn:
|
dependOn:
|
||||||
sensitive:
|
sensitive:
|
||||||
address: "localhost:50053"
|
address: "localhost:50053"
|
||||||
@@ -54,49 +45,16 @@ dependOn:
|
|||||||
accessToken: "ang1chubdev1ozhome256487d22sapguuv1ozhom"
|
accessToken: "ang1chubdev1ozhome256487d22sapguuv1ozhom"
|
||||||
tokenizer:
|
tokenizer:
|
||||||
address: "http://127.0.0.1:3002"
|
address: "http://127.0.0.1:3002"
|
||||||
vector:
|
|
||||||
# 向量后端:tencent / pgvector
|
|
||||||
provider: "pgvector"
|
|
||||||
# 历史问答命中阈值
|
|
||||||
threshold: 0.99
|
|
||||||
tencent:
|
|
||||||
url: "http://lb-4u4r1fk4-1ys6gv3rpmdan420.clb.ap-guangzhou.tencentclb.com:60000"
|
|
||||||
username: "root"
|
|
||||||
pwd: "YaUfVueWZJ20e4ghyLlBT8Dou5OapwpFTUq50oft"
|
|
||||||
database: "ai-chat"
|
|
||||||
timeout: 5
|
|
||||||
maxIdleConnPerHost: 2
|
|
||||||
readConsistency: "eventualConsistency"
|
|
||||||
idleConnTimeout: 60
|
|
||||||
pgvector:
|
|
||||||
dsn: "postgres://postgres:postgres@127.0.0.1:15432/ai_chat?sslmode=disable"
|
|
||||||
table: "chat_record_vectors"
|
|
||||||
dimensions: 1024
|
|
||||||
maxLifeTime: 3600
|
|
||||||
maxOpenConn: 10
|
|
||||||
maxIdleConn: 10
|
|
||||||
embedding:
|
embedding:
|
||||||
provider: "openai-compatible"
|
provider: "openai-compatible"
|
||||||
# 智谱 OpenAI 兼容网关;可被项目根目录 .env 覆盖
|
# 智谱 OpenAI 兼容网关;可被项目根目录 .env 覆盖
|
||||||
base_url: "https://open.bigmodel.cn/api/paas/v4"
|
base_url: "https://open.bigmodel.cn/api/paas/v4"
|
||||||
# 默认故意设成错误值,真实 key 请放到项目根目录 .env
|
# 默认故意设成错误值,真实 key 请放到项目根目录 .env
|
||||||
api_key: "__INVALID_SET_AI_CHAT_EMBEDDING_API_KEY__"
|
api_key: "__INVALID_SET_AI_CHAT_EMBEDDING_API_KEY__"
|
||||||
# embedding-2 固定 1024 维,和当前 pgvector 表结构一致
|
|
||||||
model: "embedding-2"
|
model: "embedding-2"
|
||||||
timeout: 10
|
timeout: 10
|
||||||
vectorDB:
|
faiss:
|
||||||
# 访问地址
|
base_url: "http://127.0.0.1:8451"
|
||||||
url: "http://lb-4u4r1fk4-1ys6gv3rpmdan420.clb.ap-guangzhou.tencentclb.com:60000"
|
search_k: 1
|
||||||
# 用户名
|
similarity_threshold: 0.9
|
||||||
username: "root"
|
timeout: 10
|
||||||
# 密码
|
|
||||||
pwd: "YaUfVueWZJ20e4ghyLlBT8Dou5OapwpFTUq50oft"
|
|
||||||
database: "ai-chat"
|
|
||||||
# 请求超时时长s
|
|
||||||
timeout: 5
|
|
||||||
# 最大空闲连接数
|
|
||||||
maxIdleConnPerHost: 2
|
|
||||||
# 读一致性: strongConsistency(强一致性),eventualConsistency(最终一致性)
|
|
||||||
readConsistency: "eventualConsistency"
|
|
||||||
# 空闲连接超时时长s
|
|
||||||
idleConnTimeout: 60
|
|
||||||
|
|||||||
@@ -16,17 +16,10 @@ chat:
|
|||||||
frequency_penalty: 0
|
frequency_penalty: 0
|
||||||
bot_desc: "你是一个AI助手,我需要你模拟一名资深的软件工程师来回答我的问题"
|
bot_desc: "你是一个AI助手,我需要你模拟一名资深的软件工程师来回答我的问题"
|
||||||
min_response_tokens: 600
|
min_response_tokens: 600
|
||||||
context_ttl: 1800
|
|
||||||
context_len: 4
|
|
||||||
redis:
|
redis:
|
||||||
host: "host.docker.internal"
|
host: "host.docker.internal"
|
||||||
port: 8888
|
port: 8888
|
||||||
pwd: "123456"
|
pwd: "123456"
|
||||||
mysql:
|
|
||||||
dsn: "root:root@tcp(mysql:3306)/ai_chat?collation=utf8mb4_unicode_ci&charset=utf8mb4"
|
|
||||||
maxLifeTime: 3600
|
|
||||||
maxOpenConn: 10
|
|
||||||
maxIdleConn: 10
|
|
||||||
dependOn:
|
dependOn:
|
||||||
sensitive:
|
sensitive:
|
||||||
address: "sensitive-filter:50053"
|
address: "sensitive-filter:50053"
|
||||||
@@ -36,19 +29,14 @@ dependOn:
|
|||||||
accessToken: "ang1chubdev1ozhome256487d22sapguuv1ozhom"
|
accessToken: "ang1chubdev1ozhome256487d22sapguuv1ozhom"
|
||||||
tokenizer:
|
tokenizer:
|
||||||
address: "http://tokenizer:3002"
|
address: "http://tokenizer:3002"
|
||||||
vector:
|
|
||||||
provider: "pgvector"
|
|
||||||
threshold: 0.99
|
|
||||||
pgvector:
|
|
||||||
dsn: "postgres://postgres:postgres@pgvector:5432/ai_chat?sslmode=disable"
|
|
||||||
table: "chat_record_vectors"
|
|
||||||
dimensions: 1024
|
|
||||||
maxLifeTime: 3600
|
|
||||||
maxOpenConn: 10
|
|
||||||
maxIdleConn: 10
|
|
||||||
embedding:
|
embedding:
|
||||||
provider: "openai-compatible"
|
provider: "openai-compatible"
|
||||||
base_url: "https://open.bigmodel.cn/api/paas/v4"
|
base_url: "https://open.bigmodel.cn/api/paas/v4"
|
||||||
api_key: "d51b903546814cc9981d3649a4a899a3.NQOtz3ocRtQwimh9"
|
api_key: "d51b903546814cc9981d3649a4a899a3.NQOtz3ocRtQwimh9"
|
||||||
model: "embedding-2"
|
model: "embedding-2"
|
||||||
timeout: 10
|
timeout: 10
|
||||||
|
faiss:
|
||||||
|
base_url: "http://host.docker.internal:8451"
|
||||||
|
search_k: 1
|
||||||
|
similarity_threshold: 0.9
|
||||||
|
timeout: 10
|
||||||
|
|||||||
@@ -3,58 +3,43 @@ module ai-chat-service
|
|||||||
go 1.25.0
|
go 1.25.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-sql-driver/mysql v1.8.1
|
|
||||||
github.com/golang/protobuf v1.5.4
|
github.com/golang/protobuf v1.5.4
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/jackc/pgx v3.6.2+incompatible
|
|
||||||
github.com/prometheus/client_golang v1.20.4
|
github.com/prometheus/client_golang v1.20.4
|
||||||
github.com/redis/go-redis/v9 v9.6.1
|
github.com/redis/go-redis/v9 v9.6.1
|
||||||
github.com/sashabaranov/go-openai v1.9.4
|
github.com/sashabaranov/go-openai v1.9.4
|
||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/spf13/viper v1.19.0
|
github.com/spf13/viper v1.19.0
|
||||||
github.com/tencent/vectordatabase-sdk-go v1.3.5
|
|
||||||
google.golang.org/grpc v1.65.0
|
google.golang.org/grpc v1.65.0
|
||||||
google.golang.org/protobuf v1.34.2
|
google.golang.org/protobuf v1.34.2
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
filippo.io/edwards25519 v1.1.0 // indirect
|
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/clbanning/mxj v1.8.4 // indirect
|
|
||||||
github.com/cockroachdb/apd v1.1.0 // indirect
|
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||||
github.com/gofrs/uuid v4.4.0+incompatible // indirect
|
|
||||||
github.com/google/go-querystring v1.0.0 // indirect
|
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/jackc/fake v0.0.0-20150926172116-812a484cc733 // indirect
|
|
||||||
github.com/klauspost/compress v1.17.9 // indirect
|
github.com/klauspost/compress v1.17.9 // indirect
|
||||||
github.com/lib/pq v1.12.3 // indirect
|
|
||||||
github.com/magiconair/properties v1.8.7 // indirect
|
github.com/magiconair/properties v1.8.7 // indirect
|
||||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||||
github.com/mozillazg/go-httpheader v0.2.1 // indirect
|
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
|
||||||
github.com/prometheus/client_model v0.6.1 // indirect
|
github.com/prometheus/client_model v0.6.1 // indirect
|
||||||
github.com/prometheus/common v0.55.0 // indirect
|
github.com/prometheus/common v0.55.0 // indirect
|
||||||
github.com/prometheus/procfs v0.15.1 // indirect
|
github.com/prometheus/procfs v0.15.1 // indirect
|
||||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||||
github.com/shopspring/decimal v1.4.0 // indirect
|
|
||||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||||
github.com/spf13/afero v1.11.0 // indirect
|
github.com/spf13/afero v1.11.0 // indirect
|
||||||
github.com/spf13/cast v1.6.0 // indirect
|
github.com/spf13/cast v1.6.0 // indirect
|
||||||
github.com/spf13/pflag v1.0.5 // indirect
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
github.com/stretchr/testify v1.11.1 // indirect
|
github.com/stretchr/testify v1.11.1 // indirect
|
||||||
github.com/subosito/gotenv v1.6.0 // indirect
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
github.com/tencentyun/cos-go-sdk-v5 v0.7.54 // indirect
|
|
||||||
go.uber.org/atomic v1.9.0 // indirect
|
go.uber.org/atomic v1.9.0 // indirect
|
||||||
go.uber.org/multierr v1.9.0 // indirect
|
go.uber.org/multierr v1.9.0 // indirect
|
||||||
golang.org/x/crypto v0.24.0 // indirect
|
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
|
|
||||||
golang.org/x/net v0.26.0 // indirect
|
golang.org/x/net v0.26.0 // indirect
|
||||||
golang.org/x/sys v0.22.0 // indirect
|
golang.org/x/sys v0.22.0 // indirect
|
||||||
golang.org/x/text v0.29.0 // indirect
|
golang.org/x/text v0.29.0 // indirect
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -28,14 +28,6 @@ type Config struct {
|
|||||||
FrequencyPenalty float32 `mapstructure:"frequency_penalty"`
|
FrequencyPenalty float32 `mapstructure:"frequency_penalty"`
|
||||||
BotDesc string `mapstructure:"bot_desc"`
|
BotDesc string `mapstructure:"bot_desc"`
|
||||||
MinResponseTokens int `mapstructure:"min_response_tokens"`
|
MinResponseTokens int `mapstructure:"min_response_tokens"`
|
||||||
ContextTTL int `mapstructure:"context_ttl"`
|
|
||||||
ContextLen int `mapstructure:"context_len"`
|
|
||||||
}
|
|
||||||
Mysql struct {
|
|
||||||
DSN string
|
|
||||||
MaxLifeTime int
|
|
||||||
MaxOpenConn int
|
|
||||||
MaxIdleConn int
|
|
||||||
}
|
}
|
||||||
Redis struct {
|
Redis struct {
|
||||||
Host string
|
Host string
|
||||||
@@ -55,28 +47,6 @@ type Config struct {
|
|||||||
Address string
|
Address string
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Vector struct {
|
|
||||||
Provider string
|
|
||||||
Threshold float32
|
|
||||||
Tencent struct {
|
|
||||||
Url string
|
|
||||||
Username string
|
|
||||||
Pwd string
|
|
||||||
Database string
|
|
||||||
Timeout int
|
|
||||||
MaxIdleConnPerHost int
|
|
||||||
ReadConsistency string
|
|
||||||
IdleConnTimeout int
|
|
||||||
}
|
|
||||||
Pgvector struct {
|
|
||||||
DSN string `mapstructure:"dsn"`
|
|
||||||
Table string `mapstructure:"table"`
|
|
||||||
Dimensions int `mapstructure:"dimensions"`
|
|
||||||
MaxLifeTime int `mapstructure:"maxLifeTime"`
|
|
||||||
MaxOpenConn int `mapstructure:"maxOpenConn"`
|
|
||||||
MaxIdleConn int `mapstructure:"maxIdleConn"`
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Embedding struct {
|
Embedding struct {
|
||||||
Provider string
|
Provider string
|
||||||
BaseUrl string `mapstructure:"base_url"`
|
BaseUrl string `mapstructure:"base_url"`
|
||||||
@@ -84,15 +54,11 @@ type Config struct {
|
|||||||
Model string `mapstructure:"model"`
|
Model string `mapstructure:"model"`
|
||||||
Timeout int
|
Timeout int
|
||||||
}
|
}
|
||||||
VectorDB struct {
|
Faiss struct {
|
||||||
Url string
|
BaseUrl string `mapstructure:"base_url"`
|
||||||
Username string
|
SearchK int `mapstructure:"search_k"`
|
||||||
Pwd string
|
SimilarityThreshold float32 `mapstructure:"similarity_threshold"`
|
||||||
Database string
|
Timeout int
|
||||||
Timeout int
|
|
||||||
MaxIdleConnPerHost int
|
|
||||||
ReadConsistency string
|
|
||||||
IdleConnTimeout int
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,39 +89,6 @@ func GetConfig() *Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func normalizeConfig(conf *Config) {
|
func normalizeConfig(conf *Config) {
|
||||||
if conf.Vector.Provider == "" {
|
|
||||||
conf.Vector.Provider = "tencent"
|
|
||||||
}
|
|
||||||
if conf.Vector.Threshold == 0 {
|
|
||||||
conf.Vector.Threshold = 0.99
|
|
||||||
}
|
|
||||||
|
|
||||||
// Backfill the new vector.tencent block from the legacy vectorDB config.
|
|
||||||
if conf.Vector.Tencent.Url == "" {
|
|
||||||
conf.Vector.Tencent.Url = conf.VectorDB.Url
|
|
||||||
}
|
|
||||||
if conf.Vector.Tencent.Username == "" {
|
|
||||||
conf.Vector.Tencent.Username = conf.VectorDB.Username
|
|
||||||
}
|
|
||||||
if conf.Vector.Tencent.Pwd == "" {
|
|
||||||
conf.Vector.Tencent.Pwd = conf.VectorDB.Pwd
|
|
||||||
}
|
|
||||||
if conf.Vector.Tencent.Database == "" {
|
|
||||||
conf.Vector.Tencent.Database = conf.VectorDB.Database
|
|
||||||
}
|
|
||||||
if conf.Vector.Tencent.Timeout == 0 {
|
|
||||||
conf.Vector.Tencent.Timeout = conf.VectorDB.Timeout
|
|
||||||
}
|
|
||||||
if conf.Vector.Tencent.MaxIdleConnPerHost == 0 {
|
|
||||||
conf.Vector.Tencent.MaxIdleConnPerHost = conf.VectorDB.MaxIdleConnPerHost
|
|
||||||
}
|
|
||||||
if conf.Vector.Tencent.ReadConsistency == "" {
|
|
||||||
conf.Vector.Tencent.ReadConsistency = conf.VectorDB.ReadConsistency
|
|
||||||
}
|
|
||||||
if conf.Vector.Tencent.IdleConnTimeout == 0 {
|
|
||||||
conf.Vector.Tencent.IdleConnTimeout = conf.VectorDB.IdleConnTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.Embedding.Provider == "" {
|
if conf.Embedding.Provider == "" {
|
||||||
conf.Embedding.Provider = "openai-compatible"
|
conf.Embedding.Provider = "openai-compatible"
|
||||||
}
|
}
|
||||||
@@ -168,6 +101,18 @@ func normalizeConfig(conf *Config) {
|
|||||||
if conf.Embedding.Timeout == 0 {
|
if conf.Embedding.Timeout == 0 {
|
||||||
conf.Embedding.Timeout = 10
|
conf.Embedding.Timeout = 10
|
||||||
}
|
}
|
||||||
|
if conf.Faiss.BaseUrl == "" {
|
||||||
|
conf.Faiss.BaseUrl = "http://127.0.0.1:8451"
|
||||||
|
}
|
||||||
|
if conf.Faiss.SearchK == 0 {
|
||||||
|
conf.Faiss.SearchK = 1
|
||||||
|
}
|
||||||
|
if conf.Faiss.SimilarityThreshold == 0 {
|
||||||
|
conf.Faiss.SimilarityThreshold = 0.9
|
||||||
|
}
|
||||||
|
if conf.Faiss.Timeout == 0 {
|
||||||
|
conf.Faiss.Timeout = 10
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func applySecretEnvOverrides(conf *Config) {
|
func applySecretEnvOverrides(conf *Config) {
|
||||||
@@ -177,6 +122,9 @@ func applySecretEnvOverrides(conf *Config) {
|
|||||||
if v := os.Getenv("AI_CHAT_EMBEDDING_API_KEY"); v != "" {
|
if v := os.Getenv("AI_CHAT_EMBEDDING_API_KEY"); v != "" {
|
||||||
conf.Embedding.ApiKey = v
|
conf.Embedding.ApiKey = v
|
||||||
}
|
}
|
||||||
|
if v := os.Getenv("AI_CHAT_FAISS_BASE_URL"); v != "" {
|
||||||
|
conf.Faiss.BaseUrl = v
|
||||||
|
}
|
||||||
if v := os.Getenv("REDIS_PASSWORD"); v != "" {
|
if v := os.Getenv("REDIS_PASSWORD"); v != "" {
|
||||||
conf.Redis.Pwd = v
|
conf.Redis.Pwd = v
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,28 +0,0 @@
|
|||||||
package mysql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"ai-chat-service/pkg/config"
|
|
||||||
"database/sql"
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var db *sql.DB
|
|
||||||
|
|
||||||
func InitMysql(cnf *config.Config) {
|
|
||||||
var err error
|
|
||||||
if cnf.Mysql.DSN == "" {
|
|
||||||
panic("数据库连接字符串不能为空")
|
|
||||||
}
|
|
||||||
db, err = sql.Open("mysql", cnf.Mysql.DSN)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
db.SetMaxOpenConns(cnf.Mysql.MaxOpenConn)
|
|
||||||
db.SetMaxIdleConns(cnf.Mysql.MaxIdleConn)
|
|
||||||
db.SetConnMaxLifetime(time.Second * time.Duration(cnf.Mysql.MaxLifeTime))
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetDB() *sql.DB {
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package vector
|
|
||||||
|
|
||||||
import (
|
|
||||||
"ai-chat-service/pkg/config"
|
|
||||||
"ai-chat-service/pkg/log"
|
|
||||||
"github.com/tencent/vectordatabase-sdk-go/tcvectordb"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var vdb *tcvectordb.Client
|
|
||||||
|
|
||||||
func InitDB(config *config.Config) {
|
|
||||||
var defaultOption = &tcvectordb.ClientOption{
|
|
||||||
Timeout: time.Second * time.Duration(config.VectorDB.Timeout),
|
|
||||||
MaxIdldConnPerHost: config.VectorDB.MaxIdleConnPerHost,
|
|
||||||
IdleConnTimeout: time.Second * time.Duration(config.VectorDB.IdleConnTimeout),
|
|
||||||
ReadConsistency: tcvectordb.ReadConsistency(config.VectorDB.ReadConsistency),
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
vdb, err = tcvectordb.NewClient(config.VectorDB.Url, config.VectorDB.Username, config.VectorDB.Pwd, defaultOption)
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetVdb() *tcvectordb.Client {
|
|
||||||
return vdb
|
|
||||||
}
|
|
||||||
@@ -19,7 +19,6 @@ type openAICompatibleEmbedder struct {
|
|||||||
baseURL string
|
baseURL string
|
||||||
apiKey string
|
apiKey string
|
||||||
model string
|
model string
|
||||||
dimensions int
|
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,7 +40,6 @@ func NewEmbedder(cnf *config.Config) (Embedder, error) {
|
|||||||
baseURL: strings.TrimRight(cnf.Embedding.BaseUrl, "/"),
|
baseURL: strings.TrimRight(cnf.Embedding.BaseUrl, "/"),
|
||||||
apiKey: cnf.Embedding.ApiKey,
|
apiKey: cnf.Embedding.ApiKey,
|
||||||
model: cnf.Embedding.Model,
|
model: cnf.Embedding.Model,
|
||||||
dimensions: cnf.Vector.Pgvector.Dimensions,
|
|
||||||
httpClient: &http.Client{Timeout: time.Duration(cnf.Embedding.Timeout) * time.Second},
|
httpClient: &http.Client{Timeout: time.Duration(cnf.Embedding.Timeout) * time.Second},
|
||||||
}, nil
|
}, nil
|
||||||
default:
|
default:
|
||||||
@@ -108,8 +106,5 @@ func (e *openAICompatibleEmbedder) Embed(ctx context.Context, text string) ([]fl
|
|||||||
if len(result.Data) == 0 || len(result.Data[0].Embedding) == 0 {
|
if len(result.Data) == 0 || len(result.Data[0].Embedding) == 0 {
|
||||||
return nil, fmt.Errorf("embedding response is empty")
|
return nil, fmt.Errorf("embedding response is empty")
|
||||||
}
|
}
|
||||||
if e.dimensions > 0 && len(result.Data[0].Embedding) != e.dimensions {
|
|
||||||
return nil, fmt.Errorf("embedding dimension mismatch: got=%d want=%d", len(result.Data[0].Embedding), e.dimensions)
|
|
||||||
}
|
|
||||||
return result.Data[0].Embedding, nil
|
return result.Data[0].Embedding, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,35 +1,4 @@
|
|||||||
services:
|
services:
|
||||||
mysql:
|
|
||||||
image: mysql:8.0
|
|
||||||
container_name: ai-chat-mysql
|
|
||||||
restart: unless-stopped
|
|
||||||
environment:
|
|
||||||
MYSQL_ROOT_PASSWORD: root
|
|
||||||
command:
|
|
||||||
- --default-authentication-plugin=mysql_native_password
|
|
||||||
volumes:
|
|
||||||
- /data/mysql:/var/lib/mysql
|
|
||||||
- /home/lian/share/aichat/init/create_db.sql:/docker-entrypoint-initdb.d/create_db.sql:ro
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "mysqladmin", "ping", "-h", "127.0.0.1", "-proot"]
|
|
||||||
interval: 15s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 10
|
|
||||||
|
|
||||||
pgvector:
|
|
||||||
image: pgvector/pgvector:pg16
|
|
||||||
container_name: ai-chat-pgvector
|
|
||||||
restart: unless-stopped
|
|
||||||
environment:
|
|
||||||
POSTGRES_DB: ai_chat
|
|
||||||
POSTGRES_USER: postgres
|
|
||||||
POSTGRES_PASSWORD: postgres
|
|
||||||
ports:
|
|
||||||
- "15432:5432"
|
|
||||||
volumes:
|
|
||||||
- /data/pgvector:/var/lib/postgresql/data
|
|
||||||
- /home/lian/share/aichat/init/pgvector-init.sql:/docker-entrypoint-initdb.d/pgvector-init.sql:ro
|
|
||||||
|
|
||||||
tokenizer:
|
tokenizer:
|
||||||
build:
|
build:
|
||||||
context: ../tokenizer
|
context: ../tokenizer
|
||||||
@@ -83,11 +52,9 @@ services:
|
|||||||
ports:
|
ports:
|
||||||
- "50055:50055"
|
- "50055:50055"
|
||||||
depends_on:
|
depends_on:
|
||||||
- mysql
|
|
||||||
- tokenizer
|
- tokenizer
|
||||||
- sensitive-filter
|
- sensitive-filter
|
||||||
- keywords-filter
|
- keywords-filter
|
||||||
- pgvector
|
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "grpc_health_probe", "-addr=:50055"]
|
test: ["CMD", "grpc_health_probe", "-addr=:50055"]
|
||||||
interval: 15s
|
interval: 15s
|
||||||
|
|||||||
@@ -16,17 +16,10 @@ chat:
|
|||||||
frequency_penalty: 0
|
frequency_penalty: 0
|
||||||
bot_desc: "你是一个AI助手,我需要你模拟一名资深的软件工程师来回答我的问题"
|
bot_desc: "你是一个AI助手,我需要你模拟一名资深的软件工程师来回答我的问题"
|
||||||
min_response_tokens: 600
|
min_response_tokens: 600
|
||||||
context_ttl: 1800
|
|
||||||
context_len: 4
|
|
||||||
redis:
|
redis:
|
||||||
host: "host.docker.internal"
|
host: "host.docker.internal"
|
||||||
port: 8888
|
port: 8888
|
||||||
pwd: "123456"
|
pwd: "123456"
|
||||||
mysql:
|
|
||||||
dsn: "root:root@tcp(mysql:3306)/ai_chat?collation=utf8mb4_unicode_ci&charset=utf8mb4"
|
|
||||||
maxLifeTime: 3600
|
|
||||||
maxOpenConn: 10
|
|
||||||
maxIdleConn: 10
|
|
||||||
dependOn:
|
dependOn:
|
||||||
sensitive:
|
sensitive:
|
||||||
address: "sensitive-filter:50053"
|
address: "sensitive-filter:50053"
|
||||||
@@ -36,19 +29,14 @@ dependOn:
|
|||||||
accessToken: "ang1chubdev1ozhome256487d22sapguuv1ozhom"
|
accessToken: "ang1chubdev1ozhome256487d22sapguuv1ozhom"
|
||||||
tokenizer:
|
tokenizer:
|
||||||
address: "http://tokenizer:3002"
|
address: "http://tokenizer:3002"
|
||||||
vector:
|
|
||||||
provider: "pgvector"
|
|
||||||
threshold: 0.99
|
|
||||||
pgvector:
|
|
||||||
dsn: "postgres://postgres:postgres@pgvector:5432/ai_chat?sslmode=disable"
|
|
||||||
table: "chat_record_vectors"
|
|
||||||
dimensions: 1024
|
|
||||||
maxLifeTime: 3600
|
|
||||||
maxOpenConn: 10
|
|
||||||
maxIdleConn: 10
|
|
||||||
embedding:
|
embedding:
|
||||||
provider: "openai-compatible"
|
provider: "openai-compatible"
|
||||||
base_url: "https://open.bigmodel.cn/api/paas/v4"
|
base_url: "https://open.bigmodel.cn/api/paas/v4"
|
||||||
api_key: "__SET_FROM_ENV__"
|
api_key: "__SET_FROM_ENV__"
|
||||||
model: "embedding-2"
|
model: "embedding-2"
|
||||||
timeout: 10
|
timeout: 10
|
||||||
|
faiss:
|
||||||
|
base_url: "http://host.docker.internal:8451"
|
||||||
|
search_k: 1
|
||||||
|
similarity_threshold: 0.9
|
||||||
|
timeout: 10
|
||||||
|
|||||||
@@ -50,8 +50,8 @@ export default {
|
|||||||
showRawText: 'Show as raw text',
|
showRawText: 'Show as raw text',
|
||||||
sourceSemantic: 'Semantic Match',
|
sourceSemantic: 'Semantic Match',
|
||||||
sourceLlm: 'LLM Output',
|
sourceLlm: 'LLM Output',
|
||||||
promptTokens: 'Prompt {count} tokens',
|
inputTokens: 'Input {count}',
|
||||||
completionTokens: 'Completion {count} tokens',
|
outputTokens: 'Output {count}',
|
||||||
sessionTokens: 'Session {count} tokens',
|
sessionTokens: 'Session {count} tokens',
|
||||||
},
|
},
|
||||||
setting: {
|
setting: {
|
||||||
|
|||||||
@@ -50,8 +50,8 @@ export default {
|
|||||||
showRawText: '显示原文',
|
showRawText: '显示原文',
|
||||||
sourceSemantic: '语义匹配',
|
sourceSemantic: '语义匹配',
|
||||||
sourceLlm: '大模型输出',
|
sourceLlm: '大模型输出',
|
||||||
promptTokens: '问题 {count} tokens',
|
inputTokens: 'Input {count}',
|
||||||
completionTokens: '回答 {count} tokens',
|
outputTokens: 'Output {count}',
|
||||||
sessionTokens: '本轮消耗 {count} tokens',
|
sessionTokens: '本轮消耗 {count} tokens',
|
||||||
},
|
},
|
||||||
setting: {
|
setting: {
|
||||||
|
|||||||
@@ -50,8 +50,8 @@ export default {
|
|||||||
showRawText: '顯示原文',
|
showRawText: '顯示原文',
|
||||||
sourceSemantic: '語義匹配',
|
sourceSemantic: '語義匹配',
|
||||||
sourceLlm: '大模型輸出',
|
sourceLlm: '大模型輸出',
|
||||||
promptTokens: '問題 {count} tokens',
|
inputTokens: 'Input {count}',
|
||||||
completionTokens: '回答 {count} tokens',
|
outputTokens: 'Output {count}',
|
||||||
sessionTokens: '本輪消耗 {count} tokens',
|
sessionTokens: '本輪消耗 {count} tokens',
|
||||||
},
|
},
|
||||||
setting: {
|
setting: {
|
||||||
|
|||||||
@@ -62,15 +62,20 @@ const sourceClass = computed(() => {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
const usageLabel = computed(() => {
|
const inputUsageLabel = computed(() => {
|
||||||
const usage = props.messageMeta?.usage
|
const usage = props.messageMeta?.usage
|
||||||
if (!usage)
|
if (!usage || props.inversion)
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
if (props.inversion)
|
return t('chat.inputTokens', { count: usage.prompt_tokens })
|
||||||
return t('chat.promptTokens', { count: usage.prompt_tokens })
|
})
|
||||||
|
|
||||||
return t('chat.completionTokens', { count: usage.completion_tokens })
|
const outputUsageLabel = computed(() => {
|
||||||
|
const usage = props.messageMeta?.usage
|
||||||
|
if (!usage || props.inversion)
|
||||||
|
return ''
|
||||||
|
|
||||||
|
return t('chat.outputTokens', { count: usage.completion_tokens })
|
||||||
})
|
})
|
||||||
|
|
||||||
const options = computed(() => {
|
const options = computed(() => {
|
||||||
@@ -130,21 +135,35 @@ function handleRegenerate() {
|
|||||||
<AvatarComponent :image="inversion" />
|
<AvatarComponent :image="inversion" />
|
||||||
</div>
|
</div>
|
||||||
<div class="overflow-hidden text-sm " :class="[inversion ? 'items-end' : 'items-start']">
|
<div class="overflow-hidden text-sm " :class="[inversion ? 'items-end' : 'items-start']">
|
||||||
<div class="flex flex-wrap items-center gap-2 text-xs" :class="[inversion ? 'justify-end' : 'justify-start']">
|
<div class="flex flex-col gap-2" :class="[inversion ? 'items-end' : 'items-start']">
|
||||||
<span class="font-medium text-[#9aa4af] dark:text-neutral-500">{{ dateTime }}</span>
|
<span class="text-xs font-medium text-[#9aa4af] dark:text-neutral-500">{{ dateTime }}</span>
|
||||||
<span
|
<div
|
||||||
v-if="sourceLabel"
|
v-if="sourceLabel || inputUsageLabel || outputUsageLabel"
|
||||||
class="rounded-full border px-2.5 py-1 text-[11px] font-medium leading-none"
|
class="flex flex-wrap items-center gap-2 rounded-2xl border border-[#e6edf3] bg-white/85 px-2.5 py-2 shadow-[0_10px_25px_rgba(15,23,42,0.06)] backdrop-blur-sm dark:border-neutral-800 dark:bg-[#14161a]/90"
|
||||||
:class="sourceClass"
|
|
||||||
>
|
>
|
||||||
{{ sourceLabel }}
|
<div
|
||||||
</span>
|
v-if="sourceLabel"
|
||||||
<span
|
class="inline-flex items-center gap-1.5 rounded-full border px-2.5 py-1 text-[11px] font-semibold leading-none"
|
||||||
v-if="usageLabel"
|
:class="sourceClass"
|
||||||
class="rounded-full border border-violet-200 bg-violet-50 px-2.5 py-1 text-[11px] font-medium leading-none text-violet-700 dark:border-violet-900/60 dark:bg-violet-950/40 dark:text-violet-300"
|
>
|
||||||
>
|
<SvgIcon icon="ri:radar-line" />
|
||||||
{{ usageLabel }}
|
<span>{{ sourceLabel }}</span>
|
||||||
</span>
|
</div>
|
||||||
|
<div
|
||||||
|
v-if="inputUsageLabel"
|
||||||
|
class="inline-flex items-center gap-1.5 rounded-full border border-violet-200 bg-violet-50 px-2.5 py-1 text-[11px] font-semibold leading-none text-violet-700 dark:border-violet-900/60 dark:bg-violet-950/40 dark:text-violet-300"
|
||||||
|
>
|
||||||
|
<SvgIcon icon="ri:login-circle-line" />
|
||||||
|
<span>{{ inputUsageLabel }}</span>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
v-if="outputUsageLabel"
|
||||||
|
class="inline-flex items-center gap-1.5 rounded-full border border-fuchsia-200 bg-fuchsia-50 px-2.5 py-1 text-[11px] font-semibold leading-none text-fuchsia-700 dark:border-fuchsia-900/60 dark:bg-fuchsia-950/40 dark:text-fuchsia-300"
|
||||||
|
>
|
||||||
|
<SvgIcon icon="ri:logout-circle-r-line" />
|
||||||
|
<span>{{ outputUsageLabel }}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
class="flex items-end gap-1 mt-2"
|
class="flex items-end gap-1 mt-2"
|
||||||
|
|||||||
@@ -95,45 +95,7 @@ function normalizeResponseMeta(data: Chat.ConversationResponse): {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function buildPromptUsage(usage: Chat.TokenUsage): Chat.TokenUsage {
|
function updateMessageMeta(answerIndex: number, usage: Chat.TokenUsage, source: Chat.ReplySource | null, tokenUsed?: boolean) {
|
||||||
return {
|
|
||||||
prompt_tokens: usage.prompt_tokens,
|
|
||||||
completion_tokens: 0,
|
|
||||||
total_tokens: usage.prompt_tokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function buildCompletionUsage(usage: Chat.TokenUsage): Chat.TokenUsage {
|
|
||||||
return {
|
|
||||||
prompt_tokens: 0,
|
|
||||||
completion_tokens: usage.completion_tokens,
|
|
||||||
total_tokens: usage.completion_tokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function findQuestionIndex(answerIndex: number) {
|
|
||||||
for (let current = answerIndex - 1; current >= 0; current -= 1) {
|
|
||||||
if (dataSources.value[current]?.inversion)
|
|
||||||
return current
|
|
||||||
}
|
|
||||||
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
function updateMessageMeta(questionIndex: number, answerIndex: number, usage: Chat.TokenUsage, source: Chat.ReplySource | null, tokenUsed?: boolean) {
|
|
||||||
if (questionIndex >= 0) {
|
|
||||||
updateChatSome(
|
|
||||||
+uuid,
|
|
||||||
questionIndex,
|
|
||||||
{
|
|
||||||
messageMeta: {
|
|
||||||
tokenUsed,
|
|
||||||
usage: buildPromptUsage(usage),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
updateChatSome(
|
updateChatSome(
|
||||||
+uuid,
|
+uuid,
|
||||||
answerIndex,
|
answerIndex,
|
||||||
@@ -141,7 +103,7 @@ function updateMessageMeta(questionIndex: number, answerIndex: number, usage: Ch
|
|||||||
messageMeta: {
|
messageMeta: {
|
||||||
source,
|
source,
|
||||||
tokenUsed,
|
tokenUsed,
|
||||||
usage: buildCompletionUsage(usage),
|
usage,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -182,8 +144,6 @@ async function onConversation() {
|
|||||||
)
|
)
|
||||||
scrollToBottom()
|
scrollToBottom()
|
||||||
|
|
||||||
const questionIndex = dataSources.value.length - 1
|
|
||||||
|
|
||||||
loading.value = true
|
loading.value = true
|
||||||
prompt.value = ''
|
prompt.value = ''
|
||||||
|
|
||||||
@@ -250,7 +210,7 @@ async function onConversation() {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if (responseMeta.source || responseMeta.usage || responseMeta.tokenUsed !== undefined)
|
if (responseMeta.source || responseMeta.usage || responseMeta.tokenUsed !== undefined)
|
||||||
updateMessageMeta(questionIndex, answerIndex, nextUsage, responseMeta.source, responseMeta.tokenUsed)
|
updateMessageMeta(answerIndex, nextUsage, responseMeta.source, responseMeta.tokenUsed)
|
||||||
|
|
||||||
if (responseMeta.usage && !usageApplied) {
|
if (responseMeta.usage && !usageApplied) {
|
||||||
accumulatedUsage = nextUsage
|
accumulatedUsage = nextUsage
|
||||||
@@ -343,9 +303,7 @@ async function onRegenerate(index: number) {
|
|||||||
options = { ...requestOptions.options }
|
options = { ...requestOptions.options }
|
||||||
|
|
||||||
loading.value = true
|
loading.value = true
|
||||||
|
|
||||||
let accumulatedUsage = createEmptyUsage()
|
let accumulatedUsage = createEmptyUsage()
|
||||||
const questionIndex = findQuestionIndex(index)
|
|
||||||
|
|
||||||
updateChat(
|
updateChat(
|
||||||
+uuid,
|
+uuid,
|
||||||
@@ -401,7 +359,7 @@ async function onRegenerate(index: number) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if (responseMeta.source || responseMeta.usage || responseMeta.tokenUsed !== undefined)
|
if (responseMeta.source || responseMeta.usage || responseMeta.tokenUsed !== undefined)
|
||||||
updateMessageMeta(questionIndex, index, nextUsage, responseMeta.source, responseMeta.tokenUsed)
|
updateMessageMeta(index, nextUsage, responseMeta.source, responseMeta.tokenUsed)
|
||||||
|
|
||||||
if (responseMeta.usage && !usageApplied) {
|
if (responseMeta.usage && !usageApplied) {
|
||||||
accumulatedUsage = nextUsage
|
accumulatedUsage = nextUsage
|
||||||
|
|||||||
Reference in New Issue
Block a user