package server import ( "ai-chat-service/chat-server/data" metrics_bus "ai-chat-service/chat-server/metrics-bus" "ai-chat-service/pkg/config" "ai-chat-service/pkg/log" "ai-chat-service/proto" "ai-chat-service/services/embedding" "ai-chat-service/services/faiss" "context" "encoding/json" "io" "strconv" "github.com/golang/protobuf/jsonpb" "github.com/google/uuid" "github.com/sashabaranov/go-openai" ) const ( replySourceSemanticMatch = "semantic_match" replySourceLLM = "llm" ) type chatService struct { proto.UnimplementedChatServer config *config.Config log log.ILogger data data.IChatRecordsData embedder embedding.Embedder faiss faiss.Client busMetrics *metrics_bus.BusMetrics } func NewChatService(data data.IChatRecordsData, embedder embedding.Embedder, faissClient faiss.Client, config *config.Config, log log.ILogger, busMetrics *metrics_bus.BusMetrics) proto.ChatServer { return &chatService{ config: config, log: log, data: data, embedder: embedder, faiss: faissClient, busMetrics: busMetrics, } } func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompletionRequest) (*proto.ChatCompletionResponse, error) { app := s.newApp(in) ok, msg, err := app.sensitive(in) if err != nil { s.busMetrics.ErrQuestionsTotalCounter.Inc() return nil, err } if !ok { s.busMetrics.SensitiveQuestionsTotalCounter.Inc() res := app.buildChatCompletionResponse(msg) promptTokens, tokenErr := app.countMessageTokens(openai.ChatMessageRoleUser, in.Message) if tokenErr != nil { s.log.Error(tokenErr) return res, nil } usage, tokenErr := app.buildUsage(promptTokens, msg) if tokenErr != nil { s.log.Error(tokenErr) return res, nil } res.Usage = usage return res, nil } keywords := app.keywords(in) if len(keywords) > 0 { s.busMetrics.KeywordsQuestionsTotalCounter.Inc() } req, _, currTokens, _, err := app.buildChatCompletionRequest(in, false) if err != nil { s.busMetrics.ErrQuestionsTotalCounter.Inc() return nil, err } questionEmbedding, cachedRecord := s.searchCachedAnswer(ctx, in.Message) if cachedRecord != nil { res := app.buildChatCompletionResponse(cachedRecord.Answer) res.Usage = app.buildZeroUsage() res.Source = replySourceSemanticMatch return res, nil } client := app.getOpenaiClient() resp, err := client.CreateChatCompletion(ctx, req) if err != nil { s.busMetrics.ErrQuestionsTotalCounter.Inc() s.log.Error(err) return nil, err } res := &proto.ChatCompletionResponse{} bytes, err := json.Marshal(resp) if err != nil { s.busMetrics.ErrQuestionsTotalCounter.Inc() return nil, err } if err = jsonpb.UnmarshalString(string(bytes), res); err != nil { s.busMetrics.ErrQuestionsTotalCounter.Inc() return nil, err } answer := "" if len(resp.Choices) > 0 { answer = resp.Choices[0].Message.Content } usage, tokenErr := app.buildUsage(currTokens, answer) if tokenErr != nil { s.log.Error(tokenErr) } else { res.Usage = usage } res.Source = replySourceLLM if len(resp.Choices) > 0 { if err = s.persistQA(ctx, questionEmbedding, in.Message, answer); err != nil { s.log.Error(err) } else { s.busMetrics.QuestionsTotalCounter.Inc() } } return res, nil } func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stream proto.Chat_ChatCompletionStreamServer) error { app := s.newApp(in) ok, msg, err := app.sensitive(in) if err != nil { s.busMetrics.ErrQuestionsTotalCounter.Inc() return err } if !ok { s.busMetrics.SensitiveQuestionsTotalCounter.Inc() resID := uuid.New().String() start := app.buildChatCompletionStreamResponse(resID, "", "") if err = stream.Send(start); err != nil { return err } for _, res := range app.buildChatCompletionStreamResponseList(resID, msg) { if err = stream.Send(res); err != nil { return err } } final := app.buildChatCompletionStreamResponse(resID, "", "stop") promptTokens, tokenErr := app.countMessageTokens(openai.ChatMessageRoleUser, in.Message) if tokenErr != nil { s.log.Error(tokenErr) } else { usage, tokenErr := app.buildUsage(promptTokens, msg) if tokenErr != nil { s.log.Error(tokenErr) } else { final.Usage = usage } } return stream.Send(final) } keywords := app.keywords(in) if len(keywords) > 0 { s.busMetrics.KeywordsQuestionsTotalCounter.Inc() } req, _, currTokens, _, err := app.buildChatCompletionRequest(in, true) if err != nil { s.busMetrics.ErrQuestionsTotalCounter.Inc() return err } questionEmbedding, cachedRecord := s.searchCachedAnswer(stream.Context(), in.Message) if cachedRecord != nil { start := app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "") start.Source = replySourceSemanticMatch if err = stream.Send(start); err != nil { return err } for _, res := range app.buildChatCompletionStreamResponseList(cachedRecord.ID, cachedRecord.Answer) { if err = stream.Send(res); err != nil { return err } } final := app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "stop") final.Source = replySourceSemanticMatch final.Usage = app.buildZeroUsage() return stream.Send(final) } client := app.getOpenaiClient() chatStream, err := client.CreateChatCompletionStream(stream.Context(), req) if err != nil { s.busMetrics.ErrQuestionsTotalCounter.Inc() s.log.Error(err) return err } defer chatStream.Close() completionContent := "" responseID := "" for { resp, err := chatStream.Recv() if err != nil && err != io.EOF { s.busMetrics.ErrQuestionsTotalCounter.Inc() s.log.Error(err) return err } if err == io.EOF { break } if resp.ID != "" { responseID = resp.ID } completionContent += resp.Choices[0].Delta.Content res := &proto.ChatCompletionStreamResponse{} bytes, err := json.Marshal(resp) if err != nil { s.busMetrics.ErrQuestionsTotalCounter.Inc() return err } if err = jsonpb.UnmarshalString(string(bytes), res); err != nil { s.busMetrics.ErrQuestionsTotalCounter.Inc() return err } res.Source = replySourceLLM if err = stream.Send(res); err != nil { return err } } usage, tokenErr := app.buildUsage(currTokens, completionContent) if tokenErr != nil { s.busMetrics.ErrQuestionsTotalCounter.Inc() return tokenErr } if responseID == "" { responseID = uuid.New().String() } final := app.buildChatCompletionStreamResponse(responseID, "", "stop") final.Usage = usage final.Source = replySourceLLM if err = stream.Send(final); err != nil { return err } if err = s.persistQA(stream.Context(), questionEmbedding, in.Message, completionContent); err != nil { s.log.Error(err) } else { s.busMetrics.QuestionsTotalCounter.Inc() } return nil } func (s *chatService) searchCachedAnswer(ctx context.Context, question string) ([]float32, *data.ChatRecord) { embeddingVector, err := s.embedder.Embed(ctx, question) if err != nil { s.log.Error(err) return nil, nil } searchRes, err := s.faiss.Search(ctx, embeddingVector, s.config.Faiss.SearchK) if err != nil { s.log.Error(err) return embeddingVector, nil } if searchRes == nil || len(searchRes.IDs) == 0 || len(searchRes.SimilarityScores) == 0 { return embeddingVector, nil } limit := len(searchRes.IDs) if len(searchRes.SimilarityScores) < limit { limit = len(searchRes.SimilarityScores) } for i := 0; i < limit; i++ { if searchRes.IDs[i] < 0 || searchRes.SimilarityScores[i] < s.config.Faiss.SimilarityThreshold { continue } record, err := s.data.GetById(strconv.FormatInt(searchRes.IDs[i], 10)) if err != nil { s.log.Error(err) return embeddingVector, nil } if record != nil { return embeddingVector, record } } return embeddingVector, nil } func (s *chatService) persistQA(ctx context.Context, questionEmbedding []float32, question, answer string) error { if len(questionEmbedding) == 0 { vector, err := s.embedder.Embed(ctx, question) if err != nil { return err } questionEmbedding = vector } id, err := s.faiss.Insert(ctx, questionEmbedding) if err != nil { return err } return s.data.Add(&data.ChatRecord{ ID: id, Question: question, Answer: answer, }) }