faiss server
This commit is contained in:
@@ -149,6 +149,30 @@ func (a *app) buildChatCompletionResponse(msg string) *proto.ChatCompletionRespo
|
||||
}
|
||||
}
|
||||
|
||||
func (a *app) countMessageTokens(role, content string) (int, error) {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
}
|
||||
return tokenizer.GetTokens(&message, a.openaiConf.Model)
|
||||
}
|
||||
|
||||
func (a *app) buildUsage(promptTokens int, answer string) (*proto.Usage, error) {
|
||||
completionTokens := 0
|
||||
if answer != "" {
|
||||
tokens, err := a.countMessageTokens(openai.ChatMessageRoleAssistant, answer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
completionTokens = tokens
|
||||
}
|
||||
return &proto.Usage{
|
||||
PromptTokens: int32(promptTokens),
|
||||
CompletionTokens: int32(completionTokens),
|
||||
TotalTokens: int32(promptTokens + completionTokens),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *app) buildChatCompletionStreamResponse(id, delta, finishReason string) *proto.ChatCompletionStreamResponse {
|
||||
return &proto.ChatCompletionStreamResponse{
|
||||
Id: id,
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"ai-chat-service/proto"
|
||||
"ai-chat-service/services/embedding"
|
||||
"ai-chat-service/services/faiss"
|
||||
"ai-chat-service/services/tokenizer"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
@@ -19,6 +18,11 @@ import (
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
const (
|
||||
replySourceSemanticMatch = "semantic_match"
|
||||
replySourceLLM = "llm"
|
||||
)
|
||||
|
||||
type chatService struct {
|
||||
proto.UnimplementedChatServer
|
||||
config *config.Config
|
||||
@@ -50,7 +54,19 @@ func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompleti
|
||||
}
|
||||
if !ok {
|
||||
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
|
||||
return app.buildChatCompletionResponse(msg), nil
|
||||
res := app.buildChatCompletionResponse(msg)
|
||||
promptTokens, tokenErr := app.countMessageTokens(openai.ChatMessageRoleUser, in.Message)
|
||||
if tokenErr != nil {
|
||||
s.log.Error(tokenErr)
|
||||
return res, nil
|
||||
}
|
||||
usage, tokenErr := app.buildUsage(promptTokens, msg)
|
||||
if tokenErr != nil {
|
||||
s.log.Error(tokenErr)
|
||||
return res, nil
|
||||
}
|
||||
res.Usage = usage
|
||||
return res, nil
|
||||
}
|
||||
|
||||
keywords := app.keywords(in)
|
||||
@@ -58,7 +74,7 @@ func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompleti
|
||||
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
|
||||
}
|
||||
|
||||
req, _, _, _, err := app.buildChatCompletionRequest(in, false)
|
||||
req, _, currTokens, _, err := app.buildChatCompletionRequest(in, false)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return nil, err
|
||||
@@ -66,7 +82,15 @@ func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompleti
|
||||
|
||||
questionEmbedding, cachedRecord := s.searchCachedAnswer(ctx, in.Message)
|
||||
if cachedRecord != nil {
|
||||
return app.buildChatCompletionResponse(cachedRecord.Answer), nil
|
||||
res := app.buildChatCompletionResponse(cachedRecord.Answer)
|
||||
usage, tokenErr := app.buildUsage(currTokens, cachedRecord.Answer)
|
||||
if tokenErr != nil {
|
||||
s.log.Error(tokenErr)
|
||||
} else {
|
||||
res.Usage = usage
|
||||
}
|
||||
res.Source = replySourceSemanticMatch
|
||||
return res, nil
|
||||
}
|
||||
|
||||
client := app.getOpenaiClient()
|
||||
@@ -88,8 +112,20 @@ func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompleti
|
||||
return nil, err
|
||||
}
|
||||
|
||||
answer := ""
|
||||
if len(resp.Choices) > 0 {
|
||||
if err = s.persistQA(ctx, questionEmbedding, in.Message, resp.Choices[0].Message.Content); err != nil {
|
||||
answer = resp.Choices[0].Message.Content
|
||||
}
|
||||
usage, tokenErr := app.buildUsage(currTokens, answer)
|
||||
if tokenErr != nil {
|
||||
s.log.Error(tokenErr)
|
||||
} else {
|
||||
res.Usage = usage
|
||||
}
|
||||
res.Source = replySourceLLM
|
||||
|
||||
if len(resp.Choices) > 0 {
|
||||
if err = s.persistQA(ctx, questionEmbedding, in.Message, answer); err != nil {
|
||||
s.log.Error(err)
|
||||
} else {
|
||||
s.busMetrics.QuestionsTotalCounter.Inc()
|
||||
@@ -109,7 +145,8 @@ func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stre
|
||||
if !ok {
|
||||
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
|
||||
resID := uuid.New().String()
|
||||
if err = stream.Send(app.buildChatCompletionStreamResponse(resID, "", "")); err != nil {
|
||||
start := app.buildChatCompletionStreamResponse(resID, "", "")
|
||||
if err = stream.Send(start); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, res := range app.buildChatCompletionStreamResponseList(resID, msg) {
|
||||
@@ -117,7 +154,19 @@ func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stre
|
||||
return err
|
||||
}
|
||||
}
|
||||
return stream.Send(app.buildChatCompletionStreamResponse(resID, "", "stop"))
|
||||
final := app.buildChatCompletionStreamResponse(resID, "", "stop")
|
||||
promptTokens, tokenErr := app.countMessageTokens(openai.ChatMessageRoleUser, in.Message)
|
||||
if tokenErr != nil {
|
||||
s.log.Error(tokenErr)
|
||||
} else {
|
||||
usage, tokenErr := app.buildUsage(promptTokens, msg)
|
||||
if tokenErr != nil {
|
||||
s.log.Error(tokenErr)
|
||||
} else {
|
||||
final.Usage = usage
|
||||
}
|
||||
}
|
||||
return stream.Send(final)
|
||||
}
|
||||
|
||||
keywords := app.keywords(in)
|
||||
@@ -125,7 +174,7 @@ func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stre
|
||||
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
|
||||
}
|
||||
|
||||
req, _, _, _, err := app.buildChatCompletionRequest(in, true)
|
||||
req, _, currTokens, _, err := app.buildChatCompletionRequest(in, true)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return err
|
||||
@@ -133,7 +182,9 @@ func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stre
|
||||
|
||||
questionEmbedding, cachedRecord := s.searchCachedAnswer(stream.Context(), in.Message)
|
||||
if cachedRecord != nil {
|
||||
if err = stream.Send(app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "")); err != nil {
|
||||
start := app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "")
|
||||
start.Source = replySourceSemanticMatch
|
||||
if err = stream.Send(start); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, res := range app.buildChatCompletionStreamResponseList(cachedRecord.ID, cachedRecord.Answer) {
|
||||
@@ -141,7 +192,15 @@ func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stre
|
||||
return err
|
||||
}
|
||||
}
|
||||
return stream.Send(app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "stop"))
|
||||
final := app.buildChatCompletionStreamResponse(cachedRecord.ID, "", "stop")
|
||||
final.Source = replySourceSemanticMatch
|
||||
usage, tokenErr := app.buildUsage(currTokens, cachedRecord.Answer)
|
||||
if tokenErr != nil {
|
||||
s.log.Error(tokenErr)
|
||||
} else {
|
||||
final.Usage = usage
|
||||
}
|
||||
return stream.Send(final)
|
||||
}
|
||||
|
||||
client := app.getOpenaiClient()
|
||||
@@ -154,6 +213,7 @@ func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stre
|
||||
defer chatStream.Close()
|
||||
|
||||
completionContent := ""
|
||||
responseID := ""
|
||||
for {
|
||||
resp, err := chatStream.Recv()
|
||||
if err != nil && err != io.EOF {
|
||||
@@ -164,6 +224,9 @@ func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stre
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if resp.ID != "" {
|
||||
responseID = resp.ID
|
||||
}
|
||||
completionContent += resp.Choices[0].Delta.Content
|
||||
|
||||
res := &proto.ChatCompletionStreamResponse{}
|
||||
@@ -176,21 +239,24 @@ func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stre
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return err
|
||||
}
|
||||
res.Source = replySourceLLM
|
||||
if err = stream.Send(res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
model := s.config.Chat.Model
|
||||
if in.ChatParam != nil && in.ChatParam.Model != "" {
|
||||
model = in.ChatParam.Model
|
||||
}
|
||||
resultMessage := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: completionContent,
|
||||
}
|
||||
if _, err = tokenizer.GetTokens(&resultMessage, model); err != nil {
|
||||
usage, tokenErr := app.buildUsage(currTokens, completionContent)
|
||||
if tokenErr != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
return tokenErr
|
||||
}
|
||||
if responseID == "" {
|
||||
responseID = uuid.New().String()
|
||||
}
|
||||
final := app.buildChatCompletionStreamResponse(responseID, "", "stop")
|
||||
final.Usage = usage
|
||||
final.Source = replySourceLLM
|
||||
if err = stream.Send(final); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user