324 lines
8.3 KiB
Go
324 lines
8.3 KiB
Go
package server
|
|
|
|
import (
|
|
"ai-chat-service/chat-server/data"
|
|
metrics_bus "ai-chat-service/chat-server/metrics-bus"
|
|
"ai-chat-service/pkg/config"
|
|
"ai-chat-service/pkg/log"
|
|
"ai-chat-service/proto"
|
|
"ai-chat-service/services/embedding"
|
|
"ai-chat-service/services/faiss"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"strconv"
|
|
|
|
"github.com/golang/protobuf/jsonpb"
|
|
"github.com/google/uuid"
|
|
"github.com/sashabaranov/go-openai"
|
|
)
|
|
|
|
// Reply sources: the value written to a response's Source field, telling the
// client whether the answer came from the semantic cache or from the LLM.
const (
	// replySourceLLM marks an answer freshly generated by the chat model.
	replySourceLLM = "llm"
	// replySourceSemanticMatch marks a stored answer that was reused because a
	// semantically similar question was found in the cache.
	replySourceSemanticMatch = "semantic_match"
)
|
|
|
|
// chatService implements proto.ChatServer. Requests are first checked for
// sensitive content, then answered from a semantic cache (question
// embeddings searched in a FAISS index, answers loaded from stored chat
// records) and, on a cache miss, by calling the LLM.
type chatService struct {
	// Generated base type embedded into the ChatServer implementation;
	// presumably supplies default handlers for RPCs not defined here.
	proto.UnimplementedChatServer

	// Application configuration; its Faiss section provides SearchK and
	// SimilarityThreshold for cache lookups.
	config *config.Config
	// Logger for errors raised while serving requests.
	log log.ILogger
	// Store of persisted question/answer records, fetched by ID and appended to.
	data data.IChatRecordsData
	// Converts question text into the vector used for similarity search.
	embedder embedding.Embedder
	// Vector index client used to search for and insert question embeddings.
	faiss faiss.Client
	// Counters incremented per request outcome (error, sensitive, keyword, answered).
	busMetrics *metrics_bus.BusMetrics
}
|
|
|
|
func NewChatService(data data.IChatRecordsData, embedder embedding.Embedder, faissClient faiss.Client, config *config.Config, log log.ILogger, busMetrics *metrics_bus.BusMetrics) proto.ChatServer {
|
|
return &chatService{
|
|
config: config,
|
|
log: log,
|
|
data: data,
|
|
embedder: embedder,
|
|
faiss: faissClient,
|
|
busMetrics: busMetrics,
|
|
}
|
|
}
|
|
|
|
func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompletionRequest) (*proto.ChatCompletionResponse, error) {
|
|
app := s.newApp(in)
|
|
|
|
ok, msg, err := app.sensitive(in)
|
|
if err != nil {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
return nil, err
|
|
}
|
|
if !ok {
|
|
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
|
|
res := app.buildChatCompletionResponse(msg)
|
|
promptTokens, tokenErr := app.countMessageTokens(openai.ChatMessageRoleUser, in.Message)
|
|
if tokenErr != nil {
|
|
s.log.Error(tokenErr)
|
|
return res, nil
|
|
}
|
|
usage, tokenErr := app.buildUsage(promptTokens, msg)
|
|
if tokenErr != nil {
|
|
s.log.Error(tokenErr)
|
|
return res, nil
|
|
}
|
|
res.Usage = usage
|
|
return res, nil
|
|
}
|
|
|
|
keywords := app.keywords(in)
|
|
if len(keywords) > 0 {
|
|
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
|
|
}
|
|
|
|
req, _, currTokens, _, err := app.buildChatCompletionRequest(in, false)
|
|
if err != nil {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
return nil, err
|
|
}
|
|
|
|
questionEmbedding, cachedRecord := s.searchCachedAnswer(ctx, in.Message)
|
|
if cachedRecord != nil {
|
|
res := app.buildChatCompletionResponse(cachedRecord.Answer)
|
|
usage, tokenErr := app.buildUsage(currTokens, cachedRecord.Answer)
|
|
if tokenErr != nil {
|
|
s.log.Error(tokenErr)
|
|
} else {
|
|
res.Usage = usage
|
|
}
|
|
res.Source = replySourceSemanticMatch
|
|
return res, nil
|
|
}
|
|
|
|
client := app.getOpenaiClient()
|
|
resp, err := client.CreateChatCompletion(ctx, req)
|
|
if err != nil {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
s.log.Error(err)
|
|
return nil, err
|
|
}
|
|
|
|
res := &proto.ChatCompletionResponse{}
|
|
bytes, err := json.Marshal(resp)
|
|
if err != nil {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
return nil, err
|
|
}
|
|
if err = jsonpb.UnmarshalString(string(bytes), res); err != nil {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
return nil, err
|
|
}
|
|
|
|
answer := ""
|
|
if len(resp.Choices) > 0 {
|
|
answer = resp.Choices[0].Message.Content
|
|
}
|
|
usage, tokenErr := app.buildUsage(currTokens, answer)
|
|
if tokenErr != nil {
|
|
s.log.Error(tokenErr)
|
|
} else {
|
|
res.Usage = usage
|
|
}
|
|
res.Source = replySourceLLM
|
|
|
|
if len(resp.Choices) > 0 {
|
|
if err = s.persistQA(ctx, questionEmbedding, in.Message, answer); err != nil {
|
|
s.log.Error(err)
|
|
} else {
|
|
s.busMetrics.QuestionsTotalCounter.Inc()
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stream proto.Chat_ChatCompletionStreamServer) error {
|
|
app := s.newApp(in)
|
|
|
|
ok, msg, err := app.sensitive(in)
|
|
if err != nil {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
return err
|
|
}
|
|
if !ok {
|
|
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
|
|
resID := uuid.New().String()
|
|
start := app.buildChatCompletionStreamResponse(resID, "", "")
|
|
if err = stream.Send(start); err != nil {
|
|
return err
|
|
}
|
|
for _, res := range app.buildChatCompletionStreamResponseList(resID, msg) {
|
|
if err = stream.Send(res); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
final := app.buildChatCompletionStreamResponse(resID, "", "stop")
|
|
promptTokens, tokenErr := app.countMessageTokens(openai.ChatMessageRoleUser, in.Message)
|
|
if tokenErr != nil {
|
|
s.log.Error(tokenErr)
|
|
} else {
|
|
usage, tokenErr := app.buildUsage(promptTokens, msg)
|
|
if tokenErr != nil {
|
|
s.log.Error(tokenErr)
|
|
} else {
|
|
final.Usage = usage
|
|
}
|
|
}
|
|
return stream.Send(final)
|
|
}
|
|
|
|
keywords := app.keywords(in)
|
|
if len(keywords) > 0 {
|
|
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
|
|
}
|
|
|
|
req, _, currTokens, _, err := app.buildChatCompletionRequest(in, true)
|
|
if err != nil {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
return err
|
|
}
|
|
|
|
questionEmbedding, cachedRecord := s.searchCachedAnswer(stream.Context(), in.Message)
|
|
if cachedRecord != nil {
|
|
start := app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "")
|
|
start.Source = replySourceSemanticMatch
|
|
if err = stream.Send(start); err != nil {
|
|
return err
|
|
}
|
|
for _, res := range app.buildChatCompletionStreamResponseList(cachedRecord.ID, cachedRecord.Answer) {
|
|
if err = stream.Send(res); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
final := app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "stop")
|
|
final.Source = replySourceSemanticMatch
|
|
usage, tokenErr := app.buildUsage(currTokens, cachedRecord.Answer)
|
|
if tokenErr != nil {
|
|
s.log.Error(tokenErr)
|
|
} else {
|
|
final.Usage = usage
|
|
}
|
|
return stream.Send(final)
|
|
}
|
|
|
|
client := app.getOpenaiClient()
|
|
chatStream, err := client.CreateChatCompletionStream(stream.Context(), req)
|
|
if err != nil {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
s.log.Error(err)
|
|
return err
|
|
}
|
|
defer chatStream.Close()
|
|
|
|
completionContent := ""
|
|
responseID := ""
|
|
for {
|
|
resp, err := chatStream.Recv()
|
|
if err != nil && err != io.EOF {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
s.log.Error(err)
|
|
return err
|
|
}
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if resp.ID != "" {
|
|
responseID = resp.ID
|
|
}
|
|
completionContent += resp.Choices[0].Delta.Content
|
|
|
|
res := &proto.ChatCompletionStreamResponse{}
|
|
bytes, err := json.Marshal(resp)
|
|
if err != nil {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
return err
|
|
}
|
|
if err = jsonpb.UnmarshalString(string(bytes), res); err != nil {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
return err
|
|
}
|
|
res.Source = replySourceLLM
|
|
if err = stream.Send(res); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
usage, tokenErr := app.buildUsage(currTokens, completionContent)
|
|
if tokenErr != nil {
|
|
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
|
return tokenErr
|
|
}
|
|
if responseID == "" {
|
|
responseID = uuid.New().String()
|
|
}
|
|
final := app.buildChatCompletionStreamResponse(responseID, "", "stop")
|
|
final.Usage = usage
|
|
final.Source = replySourceLLM
|
|
if err = stream.Send(final); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = s.persistQA(stream.Context(), questionEmbedding, in.Message, completionContent); err != nil {
|
|
s.log.Error(err)
|
|
} else {
|
|
s.busMetrics.QuestionsTotalCounter.Inc()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *chatService) searchCachedAnswer(ctx context.Context, question string) ([]float32, *data.ChatRecord) {
|
|
embeddingVector, err := s.embedder.Embed(ctx, question)
|
|
if err != nil {
|
|
s.log.Error(err)
|
|
return nil, nil
|
|
}
|
|
|
|
searchRes, err := s.faiss.Search(ctx, embeddingVector, s.config.Faiss.SearchK)
|
|
if err != nil {
|
|
s.log.Error(err)
|
|
return embeddingVector, nil
|
|
}
|
|
if searchRes == nil || len(searchRes.IDs) == 0 || len(searchRes.SimilarityScores) == 0 {
|
|
return embeddingVector, nil
|
|
}
|
|
limit := len(searchRes.IDs)
|
|
if len(searchRes.SimilarityScores) < limit {
|
|
limit = len(searchRes.SimilarityScores)
|
|
}
|
|
for i := 0; i < limit; i++ {
|
|
if searchRes.IDs[i] < 0 || searchRes.SimilarityScores[i] < s.config.Faiss.SimilarityThreshold {
|
|
continue
|
|
}
|
|
record, err := s.data.GetById(strconv.FormatInt(searchRes.IDs[i], 10))
|
|
if err != nil {
|
|
s.log.Error(err)
|
|
return embeddingVector, nil
|
|
}
|
|
if record != nil {
|
|
return embeddingVector, record
|
|
}
|
|
}
|
|
return embeddingVector, nil
|
|
}
|
|
|
|
func (s *chatService) persistQA(ctx context.Context, questionEmbedding []float32, question, answer string) error {
|
|
if len(questionEmbedding) == 0 {
|
|
vector, err := s.embedder.Embed(ctx, question)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
questionEmbedding = vector
|
|
}
|
|
id, err := s.faiss.Insert(ctx, questionEmbedding)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.data.Add(&data.ChatRecord{
|
|
ID: id,
|
|
Question: question,
|
|
Answer: answer,
|
|
})
|
|
}
|