service 修改 Redis 存储 KV

This commit is contained in:
2026-04-10 11:12:10 +00:00
parent c888ca8844
commit bc82e3e708
25 changed files with 322 additions and 3666 deletions

View File

@@ -1,20 +1,18 @@
package server
import (
chat_context "ai-chat-service/chat-server/chat-context"
"ai-chat-service/chat-server/data"
metrics_bus "ai-chat-service/chat-server/metrics-bus"
vector_data "ai-chat-service/chat-server/vector-data"
"ai-chat-service/pkg/config"
"ai-chat-service/pkg/log"
"ai-chat-service/proto"
"ai-chat-service/services/embedding"
"ai-chat-service/services/faiss"
"ai-chat-service/services/tokenizer"
"context"
"encoding/json"
"io"
"strconv"
"strings"
"time"
"github.com/golang/protobuf/jsonpb"
"github.com/google/uuid"
@@ -26,218 +24,127 @@ type chatService struct {
config *config.Config
log log.ILogger
data data.IChatRecordsData
vectorData vector_data.IChatRecordsData
embedder embedding.Embedder
faiss faiss.Client
busMetrics *metrics_bus.BusMetrics
}
func NewChatService(data data.IChatRecordsData, vectorData vector_data.IChatRecordsData, config *config.Config, log log.ILogger, busMetrics *metrics_bus.BusMetrics) proto.ChatServer {
func NewChatService(data data.IChatRecordsData, embedder embedding.Embedder, faissClient faiss.Client, config *config.Config, log log.ILogger, busMetrics *metrics_bus.BusMetrics) proto.ChatServer {
return &chatService{
config: config,
log: log,
data: data,
vectorData: vectorData,
embedder: embedder,
faiss: faissClient,
busMetrics: busMetrics,
}
}
func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompletionRequest) (*proto.ChatCompletionResponse, error) {
redisContextCache := chat_context.NewRedisCache()
defer redisContextCache.Close()
app := s.newApp(in)
app := s.newApp(in, redisContextCache)
//敏感词过滤
ok, msg, err := app.sensitive(in)
if err != nil {
s.log.Error(err)
return nil, err
}
if !ok {
res := app.buildChatCompletionResponse(msg)
return res, nil
}
//关键词提取
keywords := app.keywords(in)
if len(keywords) > 0 {
idStr, score, err := s.vectorData.QueryData(context.Background(), map[string][]string{"keywords": {strings.Join(keywords, ",")}})
if err != nil {
s.log.Error(err)
} else if score > s.config.Vector.Threshold {
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
s.log.Error(err)
} else {
record, err := s.data.GetById(id)
if err != nil {
s.log.Error(err)
} else {
res := app.buildChatCompletionResponse(record.AIMsg)
return res, nil
}
}
}
}
client := app.getOpenaiClient()
req, tokens, currTokens, currMessage, err := app.buildChatCompletionRequest(in, false)
resp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
s.log.Error(err)
return nil, err
}
res := &proto.ChatCompletionResponse{}
bytes, err := json.Marshal(resp)
if err != nil {
s.log.Error(err)
return nil, err
}
err = jsonpb.UnmarshalString(string(bytes), res)
if err != nil {
s.log.Error(err)
return nil, err
}
go func() {
reqContext := &chat_context.ChatMessage{
ID: in.Id,
PID: in.Pid,
Message: currMessage,
Tokens: currTokens,
}
err := app.saveContext(reqContext)
if err != nil {
s.log.Error(err)
return
}
resContext := &chat_context.ChatMessage{
ID: resp.ID,
PID: reqContext.ID,
Message: resp.Choices[0].Message,
Tokens: resp.Usage.CompletionTokens,
}
err = app.saveContext(resContext)
if err != nil {
s.log.Error(err)
return
}
}()
go func() {
records := &data.ChatRecord{
UserMsg: in.Message,
UserMsgTokens: currTokens,
UserMsgKeywords: keywords,
AIMsg: resp.Choices[0].Message.Content,
AIMsgTokens: resp.Usage.CompletionTokens,
ReqTokens: tokens,
CreateAt: time.Now().Unix(),
}
err := s.data.Add(records)
if err != nil {
s.log.Error(err)
return
}
//保存到向量数据库
if len(keywords) > 0 {
list := []*vector_data.ChatRecord{
{
ID: strconv.FormatInt(records.ID, 10),
KVs: map[string]string{
"keywords": strings.Join(keywords, ","),
},
},
}
err = s.vectorData.UpsertData(context.Background(), list)
if err != nil {
s.log.Error(err)
return
}
}
}()
return res, err
}
func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stream proto.Chat_ChatCompletionStreamServer) error {
redisContextCache := chat_context.NewRedisCache()
defer redisContextCache.Close()
app := s.newApp(in, redisContextCache)
//敏感词过滤
ok, msg, err := app.sensitive(in)
if err != nil {
s.busMetrics.ErrQuestionsTotalCounter.Inc()
return nil, err
}
if !ok {
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
return app.buildChatCompletionResponse(msg), nil
}
keywords := app.keywords(in)
if len(keywords) > 0 {
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
}
req, _, _, _, err := app.buildChatCompletionRequest(in, false)
if err != nil {
s.busMetrics.ErrQuestionsTotalCounter.Inc()
return nil, err
}
questionEmbedding, cachedRecord := s.searchCachedAnswer(ctx, in.Message)
if cachedRecord != nil {
return app.buildChatCompletionResponse(cachedRecord.Answer), nil
}
client := app.getOpenaiClient()
resp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
s.busMetrics.ErrQuestionsTotalCounter.Inc()
s.log.Error(err)
return nil, err
}
res := &proto.ChatCompletionResponse{}
bytes, err := json.Marshal(resp)
if err != nil {
s.busMetrics.ErrQuestionsTotalCounter.Inc()
return nil, err
}
if err = jsonpb.UnmarshalString(string(bytes), res); err != nil {
s.busMetrics.ErrQuestionsTotalCounter.Inc()
return nil, err
}
if len(resp.Choices) > 0 {
if err = s.persistQA(ctx, questionEmbedding, in.Message, resp.Choices[0].Message.Content); err != nil {
s.log.Error(err)
} else {
s.busMetrics.QuestionsTotalCounter.Inc()
}
}
return res, nil
}
func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stream proto.Chat_ChatCompletionStreamServer) error {
app := s.newApp(in)
ok, msg, err := app.sensitive(in)
if err != nil {
s.busMetrics.ErrQuestionsTotalCounter.Inc()
return err
}
if !ok {
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
resId := uuid.New().String()
startRes := app.buildChatCompletionStreamResponse(resId, "", "")
endRes := app.buildChatCompletionStreamResponse(resId, "", "stop")
err = stream.Send(startRes)
if err != nil {
s.log.Error(err)
resID := uuid.New().String()
if err = stream.Send(app.buildChatCompletionStreamResponse(resID, "", "")); err != nil {
return err
}
resList := app.buildChatCompletionStreamResponseList(resId, msg)
for _, res := range resList {
err = stream.Send(res)
if err != nil {
s.log.Error(err)
for _, res := range app.buildChatCompletionStreamResponseList(resID, msg) {
if err = stream.Send(res); err != nil {
return err
}
}
err = stream.Send(endRes)
if err != nil {
s.log.Error(err)
return err
}
return nil
return stream.Send(app.buildChatCompletionStreamResponse(resID, "", "stop"))
}
//关键词提取
keywords := app.keywords(in)
if len(keywords) > 0 {
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
idStr, score, err := s.vectorData.QueryData(context.Background(), map[string][]string{"keywords": {strings.Join(keywords, ",")}})
if err != nil {
s.log.Error(err)
} else if score > s.config.Vector.Threshold {
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
s.log.Error(err)
} else {
record, err := s.data.GetById(id)
if err != nil {
s.log.Error(err)
} else {
resId := uuid.New().String()
startRes := app.buildChatCompletionStreamResponse(resId, "", "")
endRes := app.buildChatCompletionStreamResponse(resId, "", "stop")
err = stream.Send(startRes)
if err != nil {
s.log.Error(err)
return err
}
resList := app.buildChatCompletionStreamResponseList(resId, record.AIMsg)
for _, res := range resList {
err = stream.Send(res)
if err != nil {
s.log.Error(err)
return err
}
}
err = stream.Send(endRes)
if err != nil {
s.log.Error(err)
return err
}
return nil
}
}
req, _, _, _, err := app.buildChatCompletionRequest(in, true)
if err != nil {
s.busMetrics.ErrQuestionsTotalCounter.Inc()
return err
}
questionEmbedding, cachedRecord := s.searchCachedAnswer(stream.Context(), in.Message)
if cachedRecord != nil {
if err = stream.Send(app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "")); err != nil {
return err
}
for _, res := range app.buildChatCompletionStreamResponseList(cachedRecord.ID, cachedRecord.Answer) {
if err = stream.Send(res); err != nil {
return err
}
}
return stream.Send(app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "stop"))
}
client := app.getOpenaiClient()
req, tokens, currTokens, currMessage, err := app.buildChatCompletionRequest(in, false)
chatStream, err := client.CreateChatCompletionStream(stream.Context(), req)
if err != nil {
s.busMetrics.ErrQuestionsTotalCounter.Inc()
@@ -245,109 +152,106 @@ func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stre
return err
}
defer chatStream.Close()
completionContent := ""
resultID := ""
for {
resp, err := chatStream.Recv()
if err != nil && err != io.EOF {
s.busMetrics.ErrQuestionsTotalCounter.Inc()
s.log.Error(err)
return err
}
if err == io.EOF {
break
}
if resultID == "" {
resultID = resp.ID
}
completionContent += resp.Choices[0].Delta.Content
res := &proto.ChatCompletionStreamResponse{}
bytes, err := json.Marshal(resp)
if err != nil {
s.log.Error(err)
s.busMetrics.ErrQuestionsTotalCounter.Inc()
return err
}
err = jsonpb.UnmarshalString(string(bytes), res)
if err != nil {
s.log.Error(err)
if err = jsonpb.UnmarshalString(string(bytes), res); err != nil {
s.busMetrics.ErrQuestionsTotalCounter.Inc()
return err
}
err = stream.Send(res)
if err != nil {
s.log.Error(err)
if err = stream.Send(res); err != nil {
return err
}
}
model := s.config.Chat.Model
if in.ChatParam != nil && in.ChatParam.Model != "" {
model = in.ChatParam.Model
}
resultMessage := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: completionContent,
}
model := s.config.Chat.Model
if in.ChatParam != nil && in.ChatParam.Model != "" {
model = in.ChatParam.Model
}
resultTokens, err := tokenizer.GetTokens(&resultMessage, model)
if err != nil {
if _, err = tokenizer.GetTokens(&resultMessage, model); err != nil {
s.busMetrics.ErrQuestionsTotalCounter.Inc()
s.log.Error(err)
return err
}
go func() {
reqContext := &chat_context.ChatMessage{
ID: in.Id,
PID: in.Pid,
Message: currMessage,
Tokens: currTokens,
}
err := app.saveContext(reqContext)
if err != nil {
s.log.Error(err)
return
}
resContext := &chat_context.ChatMessage{
ID: resultID,
PID: reqContext.ID,
Message: resultMessage,
Tokens: resultTokens,
}
err = app.saveContext(resContext)
if err != nil {
s.log.Error(err)
return
}
}()
go func() {
if err = s.persistQA(stream.Context(), questionEmbedding, in.Message, completionContent); err != nil {
s.log.Error(err)
} else {
s.busMetrics.QuestionsTotalCounter.Inc()
records := &data.ChatRecord{
UserMsg: in.Message,
UserMsgTokens: currTokens,
UserMsgKeywords: keywords,
AIMsg: completionContent,
AIMsgTokens: resultTokens,
ReqTokens: tokens,
CreateAt: time.Now().Unix(),
}
err := s.data.Add(records)
if err != nil {
s.log.Error(err)
return
}
//保存到向量数据库
if len(keywords) > 0 {
list := []*vector_data.ChatRecord{
{
ID: strconv.FormatInt(records.ID, 10),
KVs: map[string]string{
"keywords": strings.Join(keywords, ","),
},
},
}
err = s.vectorData.UpsertData(context.Background(), list)
if err != nil {
s.log.Error(err)
return
}
}
}()
}
return nil
}
func (s *chatService) searchCachedAnswer(ctx context.Context, question string) ([]float32, *data.ChatRecord) {
embeddingVector, err := s.embedder.Embed(ctx, question)
if err != nil {
s.log.Error(err)
return nil, nil
}
searchRes, err := s.faiss.Search(ctx, embeddingVector, s.config.Faiss.SearchK)
if err != nil {
s.log.Error(err)
return embeddingVector, nil
}
if searchRes == nil || len(searchRes.IDs) == 0 || len(searchRes.SimilarityScores) == 0 {
return embeddingVector, nil
}
limit := len(searchRes.IDs)
if len(searchRes.SimilarityScores) < limit {
limit = len(searchRes.SimilarityScores)
}
for i := 0; i < limit; i++ {
if searchRes.IDs[i] < 0 || searchRes.SimilarityScores[i] < s.config.Faiss.SimilarityThreshold {
continue
}
record, err := s.data.GetById(strconv.FormatInt(searchRes.IDs[i], 10))
if err != nil {
s.log.Error(err)
return embeddingVector, nil
}
if record != nil {
return embeddingVector, record
}
}
return embeddingVector, nil
}
func (s *chatService) persistQA(ctx context.Context, questionEmbedding []float32, question, answer string) error {
if len(questionEmbedding) == 0 {
vector, err := s.embedder.Embed(ctx, question)
if err != nil {
return err
}
questionEmbedding = vector
}
id, err := s.faiss.Insert(ctx, questionEmbedding)
if err != nil {
return err
}
return s.data.Add(&data.ChatRecord{
ID: id,
Question: question,
Answer: answer,
})
}