service 修改 Redis 存储 KV
This commit is contained in:
@@ -1,20 +1,18 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
chat_context "ai-chat-service/chat-server/chat-context"
|
||||
"ai-chat-service/chat-server/data"
|
||||
metrics_bus "ai-chat-service/chat-server/metrics-bus"
|
||||
vector_data "ai-chat-service/chat-server/vector-data"
|
||||
"ai-chat-service/pkg/config"
|
||||
"ai-chat-service/pkg/log"
|
||||
"ai-chat-service/proto"
|
||||
"ai-chat-service/services/embedding"
|
||||
"ai-chat-service/services/faiss"
|
||||
"ai-chat-service/services/tokenizer"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/jsonpb"
|
||||
"github.com/google/uuid"
|
||||
@@ -26,218 +24,127 @@ type chatService struct {
|
||||
config *config.Config
|
||||
log log.ILogger
|
||||
data data.IChatRecordsData
|
||||
vectorData vector_data.IChatRecordsData
|
||||
embedder embedding.Embedder
|
||||
faiss faiss.Client
|
||||
busMetrics *metrics_bus.BusMetrics
|
||||
}
|
||||
|
||||
func NewChatService(data data.IChatRecordsData, vectorData vector_data.IChatRecordsData, config *config.Config, log log.ILogger, busMetrics *metrics_bus.BusMetrics) proto.ChatServer {
|
||||
func NewChatService(data data.IChatRecordsData, embedder embedding.Embedder, faissClient faiss.Client, config *config.Config, log log.ILogger, busMetrics *metrics_bus.BusMetrics) proto.ChatServer {
|
||||
return &chatService{
|
||||
config: config,
|
||||
log: log,
|
||||
data: data,
|
||||
vectorData: vectorData,
|
||||
embedder: embedder,
|
||||
faiss: faissClient,
|
||||
busMetrics: busMetrics,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompletionRequest) (*proto.ChatCompletionResponse, error) {
|
||||
redisContextCache := chat_context.NewRedisCache()
|
||||
defer redisContextCache.Close()
|
||||
app := s.newApp(in)
|
||||
|
||||
app := s.newApp(in, redisContextCache)
|
||||
//敏感词过滤
|
||||
ok, msg, err := app.sensitive(in)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
res := app.buildChatCompletionResponse(msg)
|
||||
return res, nil
|
||||
}
|
||||
|
||||
//关键词提取
|
||||
keywords := app.keywords(in)
|
||||
if len(keywords) > 0 {
|
||||
idStr, score, err := s.vectorData.QueryData(context.Background(), map[string][]string{"keywords": {strings.Join(keywords, ",")}})
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else if score > s.config.Vector.Threshold {
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else {
|
||||
record, err := s.data.GetById(id)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else {
|
||||
res := app.buildChatCompletionResponse(record.AIMsg)
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := app.getOpenaiClient()
|
||||
req, tokens, currTokens, currMessage, err := app.buildChatCompletionRequest(in, false)
|
||||
resp, err := client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
res := &proto.ChatCompletionResponse{}
|
||||
bytes, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
err = jsonpb.UnmarshalString(string(bytes), res)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
reqContext := &chat_context.ChatMessage{
|
||||
ID: in.Id,
|
||||
PID: in.Pid,
|
||||
Message: currMessage,
|
||||
Tokens: currTokens,
|
||||
}
|
||||
err := app.saveContext(reqContext)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
resContext := &chat_context.ChatMessage{
|
||||
ID: resp.ID,
|
||||
PID: reqContext.ID,
|
||||
Message: resp.Choices[0].Message,
|
||||
Tokens: resp.Usage.CompletionTokens,
|
||||
}
|
||||
err = app.saveContext(resContext)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
records := &data.ChatRecord{
|
||||
UserMsg: in.Message,
|
||||
UserMsgTokens: currTokens,
|
||||
UserMsgKeywords: keywords,
|
||||
AIMsg: resp.Choices[0].Message.Content,
|
||||
AIMsgTokens: resp.Usage.CompletionTokens,
|
||||
ReqTokens: tokens,
|
||||
CreateAt: time.Now().Unix(),
|
||||
}
|
||||
err := s.data.Add(records)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
//保存到向量数据库
|
||||
if len(keywords) > 0 {
|
||||
list := []*vector_data.ChatRecord{
|
||||
{
|
||||
ID: strconv.FormatInt(records.ID, 10),
|
||||
KVs: map[string]string{
|
||||
"keywords": strings.Join(keywords, ","),
|
||||
},
|
||||
},
|
||||
}
|
||||
err = s.vectorData.UpsertData(context.Background(), list)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return res, err
|
||||
}
|
||||
func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stream proto.Chat_ChatCompletionStreamServer) error {
|
||||
redisContextCache := chat_context.NewRedisCache()
|
||||
defer redisContextCache.Close()
|
||||
|
||||
app := s.newApp(in, redisContextCache)
|
||||
//敏感词过滤
|
||||
ok, msg, err := app.sensitive(in)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
|
||||
return app.buildChatCompletionResponse(msg), nil
|
||||
}
|
||||
|
||||
keywords := app.keywords(in)
|
||||
if len(keywords) > 0 {
|
||||
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
|
||||
}
|
||||
|
||||
req, _, _, _, err := app.buildChatCompletionRequest(in, false)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
questionEmbedding, cachedRecord := s.searchCachedAnswer(ctx, in.Message)
|
||||
if cachedRecord != nil {
|
||||
return app.buildChatCompletionResponse(cachedRecord.Answer), nil
|
||||
}
|
||||
|
||||
client := app.getOpenaiClient()
|
||||
resp, err := client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
s.log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := &proto.ChatCompletionResponse{}
|
||||
bytes, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return nil, err
|
||||
}
|
||||
if err = jsonpb.UnmarshalString(string(bytes), res); err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(resp.Choices) > 0 {
|
||||
if err = s.persistQA(ctx, questionEmbedding, in.Message, resp.Choices[0].Message.Content); err != nil {
|
||||
s.log.Error(err)
|
||||
} else {
|
||||
s.busMetrics.QuestionsTotalCounter.Inc()
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stream proto.Chat_ChatCompletionStreamServer) error {
|
||||
app := s.newApp(in)
|
||||
|
||||
ok, msg, err := app.sensitive(in)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
|
||||
resId := uuid.New().String()
|
||||
startRes := app.buildChatCompletionStreamResponse(resId, "", "")
|
||||
endRes := app.buildChatCompletionStreamResponse(resId, "", "stop")
|
||||
err = stream.Send(startRes)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
resID := uuid.New().String()
|
||||
if err = stream.Send(app.buildChatCompletionStreamResponse(resID, "", "")); err != nil {
|
||||
return err
|
||||
}
|
||||
resList := app.buildChatCompletionStreamResponseList(resId, msg)
|
||||
for _, res := range resList {
|
||||
err = stream.Send(res)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
for _, res := range app.buildChatCompletionStreamResponseList(resID, msg) {
|
||||
if err = stream.Send(res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = stream.Send(endRes)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return stream.Send(app.buildChatCompletionStreamResponse(resID, "", "stop"))
|
||||
}
|
||||
|
||||
//关键词提取
|
||||
keywords := app.keywords(in)
|
||||
|
||||
if len(keywords) > 0 {
|
||||
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
|
||||
idStr, score, err := s.vectorData.QueryData(context.Background(), map[string][]string{"keywords": {strings.Join(keywords, ",")}})
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else if score > s.config.Vector.Threshold {
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else {
|
||||
record, err := s.data.GetById(id)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else {
|
||||
resId := uuid.New().String()
|
||||
startRes := app.buildChatCompletionStreamResponse(resId, "", "")
|
||||
endRes := app.buildChatCompletionStreamResponse(resId, "", "stop")
|
||||
err = stream.Send(startRes)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
resList := app.buildChatCompletionStreamResponseList(resId, record.AIMsg)
|
||||
for _, res := range resList {
|
||||
err = stream.Send(res)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = stream.Send(endRes)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
req, _, _, _, err := app.buildChatCompletionRequest(in, true)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return err
|
||||
}
|
||||
|
||||
questionEmbedding, cachedRecord := s.searchCachedAnswer(stream.Context(), in.Message)
|
||||
if cachedRecord != nil {
|
||||
if err = stream.Send(app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "")); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, res := range app.buildChatCompletionStreamResponseList(cachedRecord.ID, cachedRecord.Answer) {
|
||||
if err = stream.Send(res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return stream.Send(app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "stop"))
|
||||
}
|
||||
|
||||
client := app.getOpenaiClient()
|
||||
req, tokens, currTokens, currMessage, err := app.buildChatCompletionRequest(in, false)
|
||||
chatStream, err := client.CreateChatCompletionStream(stream.Context(), req)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
@@ -245,109 +152,106 @@ func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stre
|
||||
return err
|
||||
}
|
||||
defer chatStream.Close()
|
||||
|
||||
completionContent := ""
|
||||
resultID := ""
|
||||
for {
|
||||
resp, err := chatStream.Recv()
|
||||
if err != nil && err != io.EOF {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if resultID == "" {
|
||||
resultID = resp.ID
|
||||
}
|
||||
completionContent += resp.Choices[0].Delta.Content
|
||||
|
||||
res := &proto.ChatCompletionStreamResponse{}
|
||||
bytes, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return err
|
||||
}
|
||||
err = jsonpb.UnmarshalString(string(bytes), res)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
if err = jsonpb.UnmarshalString(string(bytes), res); err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return err
|
||||
}
|
||||
err = stream.Send(res)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
if err = stream.Send(res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
model := s.config.Chat.Model
|
||||
if in.ChatParam != nil && in.ChatParam.Model != "" {
|
||||
model = in.ChatParam.Model
|
||||
}
|
||||
resultMessage := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: completionContent,
|
||||
}
|
||||
model := s.config.Chat.Model
|
||||
if in.ChatParam != nil && in.ChatParam.Model != "" {
|
||||
model = in.ChatParam.Model
|
||||
}
|
||||
resultTokens, err := tokenizer.GetTokens(&resultMessage, model)
|
||||
if err != nil {
|
||||
if _, err = tokenizer.GetTokens(&resultMessage, model); err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
reqContext := &chat_context.ChatMessage{
|
||||
ID: in.Id,
|
||||
PID: in.Pid,
|
||||
Message: currMessage,
|
||||
Tokens: currTokens,
|
||||
}
|
||||
err := app.saveContext(reqContext)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
resContext := &chat_context.ChatMessage{
|
||||
ID: resultID,
|
||||
PID: reqContext.ID,
|
||||
Message: resultMessage,
|
||||
Tokens: resultTokens,
|
||||
}
|
||||
err = app.saveContext(resContext)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err = s.persistQA(stream.Context(), questionEmbedding, in.Message, completionContent); err != nil {
|
||||
s.log.Error(err)
|
||||
} else {
|
||||
s.busMetrics.QuestionsTotalCounter.Inc()
|
||||
records := &data.ChatRecord{
|
||||
UserMsg: in.Message,
|
||||
UserMsgTokens: currTokens,
|
||||
UserMsgKeywords: keywords,
|
||||
AIMsg: completionContent,
|
||||
AIMsgTokens: resultTokens,
|
||||
ReqTokens: tokens,
|
||||
CreateAt: time.Now().Unix(),
|
||||
}
|
||||
err := s.data.Add(records)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
//保存到向量数据库
|
||||
if len(keywords) > 0 {
|
||||
list := []*vector_data.ChatRecord{
|
||||
{
|
||||
ID: strconv.FormatInt(records.ID, 10),
|
||||
KVs: map[string]string{
|
||||
"keywords": strings.Join(keywords, ","),
|
||||
},
|
||||
},
|
||||
}
|
||||
err = s.vectorData.UpsertData(context.Background(), list)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *chatService) searchCachedAnswer(ctx context.Context, question string) ([]float32, *data.ChatRecord) {
|
||||
embeddingVector, err := s.embedder.Embed(ctx, question)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
searchRes, err := s.faiss.Search(ctx, embeddingVector, s.config.Faiss.SearchK)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return embeddingVector, nil
|
||||
}
|
||||
if searchRes == nil || len(searchRes.IDs) == 0 || len(searchRes.SimilarityScores) == 0 {
|
||||
return embeddingVector, nil
|
||||
}
|
||||
limit := len(searchRes.IDs)
|
||||
if len(searchRes.SimilarityScores) < limit {
|
||||
limit = len(searchRes.SimilarityScores)
|
||||
}
|
||||
for i := 0; i < limit; i++ {
|
||||
if searchRes.IDs[i] < 0 || searchRes.SimilarityScores[i] < s.config.Faiss.SimilarityThreshold {
|
||||
continue
|
||||
}
|
||||
record, err := s.data.GetById(strconv.FormatInt(searchRes.IDs[i], 10))
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return embeddingVector, nil
|
||||
}
|
||||
if record != nil {
|
||||
return embeddingVector, record
|
||||
}
|
||||
}
|
||||
return embeddingVector, nil
|
||||
}
|
||||
|
||||
func (s *chatService) persistQA(ctx context.Context, questionEmbedding []float32, question, answer string) error {
|
||||
if len(questionEmbedding) == 0 {
|
||||
vector, err := s.embedder.Embed(ctx, question)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
questionEmbedding = vector
|
||||
}
|
||||
id, err := s.faiss.Insert(ctx, questionEmbedding)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.data.Add(&data.ChatRecord{
|
||||
ID: id,
|
||||
Question: question,
|
||||
Answer: answer,
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user