Files
mchat/ai-chat-service/chat-server/server/server.go
2026-04-10 11:12:10 +00:00

258 lines
6.8 KiB
Go

package server
import (
"ai-chat-service/chat-server/data"
metrics_bus "ai-chat-service/chat-server/metrics-bus"
"ai-chat-service/pkg/config"
"ai-chat-service/pkg/log"
"ai-chat-service/proto"
"ai-chat-service/services/embedding"
"ai-chat-service/services/faiss"
"ai-chat-service/services/tokenizer"
"context"
"encoding/json"
"io"
"strconv"
"github.com/golang/protobuf/jsonpb"
"github.com/google/uuid"
"github.com/sashabaranov/go-openai"
)
// chatService is the concrete implementation returned by NewChatService as a
// proto.ChatServer. It holds the collaborators the chat RPC handlers use:
// configuration, logging, record storage, embedding, vector search and
// metrics.
type chatService struct {
// Embedded generated stub from the proto package.
// NOTE(review): presumably supplies default implementations for RPCs this
// type does not define itself — confirm against the generated code.
proto.UnimplementedChatServer
// config is read for the default chat model and for the FAISS search
// parameters (SearchK, SimilarityThreshold).
config *config.Config
// log receives errors that the handlers report but do not always return.
log log.ILogger
// data stores question/answer records (Add) and looks them up by ID (GetById).
data data.IChatRecordsData
// embedder converts a question string into an embedding vector.
embedder embedding.Embedder
// faiss is the vector index client used to Search for similar questions
// and Insert new question vectors.
faiss faiss.Client
// busMetrics holds the counters incremented by the handlers (errors,
// sensitive questions, keyword questions, answered questions).
busMetrics *metrics_bus.BusMetrics
}
// NewChatService assembles a chat service from its collaborators and returns
// it as a proto.ChatServer.
//
// The arguments are stored as given; no validation or copying is performed.
func NewChatService(data data.IChatRecordsData, embedder embedding.Embedder, faissClient faiss.Client, config *config.Config, log log.ILogger, busMetrics *metrics_bus.BusMetrics) proto.ChatServer {
	svc := new(chatService)

	// Storage and retrieval collaborators.
	svc.data = data
	svc.embedder = embedder
	svc.faiss = faissClient

	// Cross-cutting collaborators.
	svc.config = config
	svc.log = log
	svc.busMetrics = busMetrics

	return svc
}
// ChatCompletion handles a unary chat request.
//
// Processing order:
//  1. Run the sensitivity check; when the request is not allowed, reply with
//     the message produced by the check instead of calling the model.
//  2. Count the request as a keyword question when any keywords are found.
//  3. Build the upstream completion request (non-streaming).
//  4. Look for a stored answer to a similar question and return it on a hit.
//  5. Otherwise call the OpenAI client, convert its response to the proto
//     type through a JSON round trip, and store the first choice's content
//     for later reuse.
//
// A failure to store the answer is only logged; the response is still
// returned. The answered-questions counter is incremented only when storing
// succeeds.
func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompletionRequest) (*proto.ChatCompletionResponse, error) {
	// fail counts the request as errored and passes err back to the caller.
	fail := func(err error) (*proto.ChatCompletionResponse, error) {
		s.busMetrics.ErrQuestionsTotalCounter.Inc()
		return nil, err
	}

	app := s.newApp(in)

	allowed, notice, err := app.sensitive(in)
	if err != nil {
		return fail(err)
	}
	if !allowed {
		s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
		return app.buildChatCompletionResponse(notice), nil
	}

	if len(app.keywords(in)) > 0 {
		s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
	}

	req, _, _, _, err := app.buildChatCompletionRequest(in, false)
	if err != nil {
		return fail(err)
	}

	// The embedding is kept even on a miss so it can be reused when storing
	// the fresh answer below.
	vector, hit := s.searchCachedAnswer(ctx, in.Message)
	if hit != nil {
		return app.buildChatCompletionResponse(hit.Answer), nil
	}

	llm := app.getOpenaiClient()
	completion, err := llm.CreateChatCompletion(ctx, req)
	if err != nil {
		s.busMetrics.ErrQuestionsTotalCounter.Inc()
		s.log.Error(err)
		return nil, err
	}

	// Convert the client's response struct into the proto message via JSON.
	raw, err := json.Marshal(completion)
	if err != nil {
		return fail(err)
	}
	out := new(proto.ChatCompletionResponse)
	if err = jsonpb.UnmarshalString(string(raw), out); err != nil {
		return fail(err)
	}

	// Nothing to store when the model returned no choices.
	if len(completion.Choices) == 0 {
		return out, nil
	}

	if err = s.persistQA(ctx, vector, in.Message, completion.Choices[0].Message.Content); err != nil {
		s.log.Error(err)
		return out, nil
	}
	s.busMetrics.QuestionsTotalCounter.Inc()
	return out, nil
}
// ChatCompletionStream handles a server-streaming chat request.
//
// Processing order:
//  1. Run the sensitivity check. When the request is not allowed, the message
//     produced by the check is streamed back under a freshly generated UUID:
//     an opening chunk, the content chunks, then a chunk with finish reason
//     "stop".
//  2. Count the request as a keyword question when any keywords are found.
//  3. Build the upstream completion request (streaming).
//  4. Look for a stored answer to a similar question. On a hit the stored
//     answer is streamed in the same three-part framing, using the record ID.
//  5. Otherwise open an upstream stream, forward every chunk to the caller
//     (converted to the proto type through a JSON round trip) and accumulate
//     the first choice's delta content.
//  6. After the upstream stream ends, store the accumulated answer for reuse.
//
// Errors from the upstream client, from conversion and from stream.Send are
// returned. A failure to store the answer is only logged. The
// answered-questions counter is incremented only when storing succeeds.
func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stream proto.Chat_ChatCompletionStreamServer) error {
	app := s.newApp(in)

	ok, msg, err := app.sensitive(in)
	if err != nil {
		s.busMetrics.ErrQuestionsTotalCounter.Inc()
		return err
	}
	if !ok {
		s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
		resID := uuid.New().String()
		if err = stream.Send(app.buildChatCompletionStreamResponse(resID, "", "")); err != nil {
			return err
		}
		for _, res := range app.buildChatCompletionStreamResponseList(resID, msg) {
			if err = stream.Send(res); err != nil {
				return err
			}
		}
		return stream.Send(app.buildChatCompletionStreamResponse(resID, "", "stop"))
	}

	if len(app.keywords(in)) > 0 {
		s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
	}

	req, _, _, _, err := app.buildChatCompletionRequest(in, true)
	if err != nil {
		s.busMetrics.ErrQuestionsTotalCounter.Inc()
		return err
	}

	// The embedding is kept even on a miss so it can be reused when storing
	// the fresh answer after streaming.
	questionEmbedding, cachedRecord := s.searchCachedAnswer(stream.Context(), in.Message)
	if cachedRecord != nil {
		if err = stream.Send(app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "")); err != nil {
			return err
		}
		for _, res := range app.buildChatCompletionStreamResponseList(cachedRecord.ID, cachedRecord.Answer) {
			if err = stream.Send(res); err != nil {
				return err
			}
		}
		return stream.Send(app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "stop"))
	}

	client := app.getOpenaiClient()
	chatStream, err := client.CreateChatCompletionStream(stream.Context(), req)
	if err != nil {
		s.busMetrics.ErrQuestionsTotalCounter.Inc()
		s.log.Error(err)
		return err
	}
	defer chatStream.Close()

	completionContent := ""
	for {
		resp, err := chatStream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			s.busMetrics.ErrQuestionsTotalCounter.Inc()
			s.log.Error(err)
			return err
		}

		// A chunk may arrive with no choices (for example a usage-only
		// chunk); indexing it unconditionally would panic the handler. Such
		// a chunk contributes no text but is still forwarded below.
		if len(resp.Choices) > 0 {
			completionContent += resp.Choices[0].Delta.Content
		}

		// Convert the client's chunk struct into the proto message via JSON.
		payload, err := json.Marshal(resp)
		if err != nil {
			s.busMetrics.ErrQuestionsTotalCounter.Inc()
			return err
		}
		res := &proto.ChatCompletionStreamResponse{}
		if err = jsonpb.UnmarshalString(string(payload), res); err != nil {
			s.busMetrics.ErrQuestionsTotalCounter.Inc()
			return err
		}
		if err = stream.Send(res); err != nil {
			return err
		}
	}

	// A model named in the request overrides the configured default.
	model := s.config.Chat.Model
	if in.ChatParam != nil && in.ChatParam.Model != "" {
		model = in.ChatParam.Model
	}
	resultMessage := openai.ChatCompletionMessage{
		Role:    openai.ChatMessageRoleAssistant,
		Content: completionContent,
	}
	// NOTE(review): the token count is discarded, yet a tokenizer error fails
	// the RPC after the whole answer has already been streamed and skips
	// storing it. Kept as-is; confirm whether this should be log-only.
	if _, err = tokenizer.GetTokens(&resultMessage, model); err != nil {
		s.busMetrics.ErrQuestionsTotalCounter.Inc()
		return err
	}

	// Never store an empty answer: it would later be served from the cache
	// as the reply to every sufficiently similar question.
	if completionContent == "" {
		return nil
	}

	if err = s.persistQA(stream.Context(), questionEmbedding, in.Message, completionContent); err != nil {
		s.log.Error(err)
		return nil
	}
	s.busMetrics.QuestionsTotalCounter.Inc()
	return nil
}
// searchCachedAnswer embeds question and looks for a previously stored
// record whose question vector is similar enough to reuse its answer.
//
// Return values:
//   - (nil, nil) when the question cannot be embedded;
//   - (vector, nil) when the vector search fails, yields no usable candidate,
//     or a record lookup fails;
//   - (vector, record) for the first candidate, in search-result order, that
//     has a non-negative ID, a score not below the configured threshold, and
//     a non-nil stored record.
//
// All errors are logged rather than returned, so a lookup problem degrades
// to a miss.
func (s *chatService) searchCachedAnswer(ctx context.Context, question string) ([]float32, *data.ChatRecord) {
	vec, err := s.embedder.Embed(ctx, question)
	if err != nil {
		s.log.Error(err)
		return nil, nil
	}

	hits, err := s.faiss.Search(ctx, vec, s.config.Faiss.SearchK)
	if err != nil {
		s.log.Error(err)
		return vec, nil
	}
	if hits == nil {
		return vec, nil
	}

	// Only positions present in both result slices can be evaluated; when
	// either slice is empty this is zero and the loop below does not run.
	usable := len(hits.SimilarityScores)
	if len(hits.IDs) < usable {
		usable = len(hits.IDs)
	}

	threshold := s.config.Faiss.SimilarityThreshold
	for pos, id := range hits.IDs[:usable] {
		if id < 0 {
			continue
		}
		if hits.SimilarityScores[pos] < threshold {
			continue
		}

		rec, err := s.data.GetById(strconv.FormatInt(id, 10))
		if err != nil {
			// A storage error ends the search instead of trying the next
			// candidate.
			s.log.Error(err)
			return vec, nil
		}
		if rec == nil {
			continue
		}
		return vec, rec
	}

	return vec, nil
}
// persistQA stores a question/answer pair so it can be found by later
// similarity searches.
//
// When questionEmbedding is empty the question is embedded first. The vector
// is inserted into the FAISS index, and the ID returned by that insert becomes
// the ID of the stored record. The first error from embedding, inserting or
// storing is returned.
//
// NOTE(review): if storing the record fails after the insert succeeded, the
// inserted vector is not removed by this function.
func (s *chatService) persistQA(ctx context.Context, questionEmbedding []float32, question, answer string) error {
	vec := questionEmbedding
	if len(vec) == 0 {
		var err error
		if vec, err = s.embedder.Embed(ctx, question); err != nil {
			return err
		}
	}

	id, err := s.faiss.Insert(ctx, vec)
	if err != nil {
		return err
	}

	rec := &data.ChatRecord{
		ID:       id,
		Question: question,
		Answer:   answer,
	}
	return s.data.Add(rec)
}