215 lines
6.2 KiB
Go
215 lines
6.2 KiB
Go
package server
|
|
|
|
import (
	"context"
	"time"
	"unicode/utf8"

	"ai-chat-service/pkg/config"
	"ai-chat-service/pkg/log"
	"ai-chat-service/pkg/zerror"
	"ai-chat-service/proto"
	"ai-chat-service/services"
	keywords_filter "ai-chat-service/services/keywords-filter"
	keywords_proto "ai-chat-service/services/keywords-filter/proto"
	"ai-chat-service/services/tokenizer"

	"github.com/google/uuid"
	"github.com/sashabaranov/go-openai"
)
|
|
|
|
// ChatPrimedTokens is a fixed token overhead that rebuildMessages adds on top
// of the system-message and user-message token counts when sizing a prompt.
// NOTE(review): presumably this covers the tokens the model's reply is primed
// with — confirm against the tokenizer package.
const ChatPrimedTokens = 2
|
|
|
|
// openaiConf holds the effective OpenAI settings for a request: service
// defaults, optionally overridden by the request's ChatParam (see newApp).
type openaiConf struct {
	// ApiKey authenticates the client; BaseUrl is the API endpoint; Model is
	// the model name used both for token counting and for the request.
	ApiKey, BaseUrl, Model string

	// MaxTokens is the overall token budget. The prompt's token count is
	// subtracted from it to obtain the completion request's MaxTokens.
	MaxTokens int

	// Sampling parameters copied verbatim into the completion request.
	Temperature, TopP, PresencePenalty, FrequencyPenalty float32

	// BotDesc, when non-empty, is sent as the system message.
	BotDesc string

	// MinResponseTokens is the number of tokens kept free for the reply when
	// checking whether the user message fits into the budget.
	MinResponseTokens int
}
|
|
|
|
type app struct {
|
|
openaiConf *openaiConf
|
|
log log.ILogger
|
|
}
|
|
|
|
func (s *chatService) newApp(in *proto.ChatCompletionRequest) *app {
|
|
conf := &openaiConf{
|
|
ApiKey: s.config.Chat.ApiKey,
|
|
BaseUrl: s.config.Chat.BaseUrl,
|
|
Model: s.config.Chat.Model,
|
|
MaxTokens: s.config.Chat.MaxTokens,
|
|
Temperature: s.config.Chat.Temperature,
|
|
TopP: s.config.Chat.TopP,
|
|
PresencePenalty: s.config.Chat.PresencePenalty,
|
|
FrequencyPenalty: s.config.Chat.FrequencyPenalty,
|
|
BotDesc: s.config.Chat.BotDesc,
|
|
MinResponseTokens: s.config.Chat.MinResponseTokens,
|
|
}
|
|
if in.ChatParam != nil {
|
|
if in.ChatParam.Model != "" {
|
|
conf.Model = in.ChatParam.Model
|
|
}
|
|
conf.TopP = in.ChatParam.TopP
|
|
conf.FrequencyPenalty = in.ChatParam.FrequencyPenalty
|
|
conf.PresencePenalty = in.ChatParam.PresencePenalty
|
|
conf.Temperature = in.ChatParam.Temperature
|
|
if in.ChatParam.BotDesc != "" {
|
|
conf.BotDesc = in.ChatParam.BotDesc
|
|
}
|
|
if in.ChatParam.MaxTokens != 0 {
|
|
conf.MaxTokens = int(in.ChatParam.MaxTokens)
|
|
}
|
|
if in.ChatParam.MinResponseTokens != 0 {
|
|
conf.MinResponseTokens = int(in.ChatParam.MinResponseTokens)
|
|
}
|
|
}
|
|
return &app{
|
|
openaiConf: conf,
|
|
log: s.log,
|
|
}
|
|
}
|
|
|
|
func (a *app) getOpenaiClient() *openai.Client {
|
|
conf := openai.DefaultConfig(a.openaiConf.ApiKey)
|
|
conf.BaseURL = a.openaiConf.BaseUrl
|
|
return openai.NewClientWithConfig(conf)
|
|
}
|
|
|
|
func (a *app) buildChatCompletionRequest(in *proto.ChatCompletionRequest, stream bool) (req openai.ChatCompletionRequest, tokens, currTokens int, currMessage openai.ChatCompletionMessage, err error) {
|
|
currMessage = openai.ChatCompletionMessage{
|
|
Role: openai.ChatMessageRoleUser,
|
|
Content: in.Message,
|
|
}
|
|
req = openai.ChatCompletionRequest{
|
|
Model: a.openaiConf.Model,
|
|
MaxTokens: a.openaiConf.MinResponseTokens,
|
|
Temperature: a.openaiConf.Temperature,
|
|
TopP: a.openaiConf.TopP,
|
|
PresencePenalty: a.openaiConf.PresencePenalty,
|
|
FrequencyPenalty: a.openaiConf.FrequencyPenalty,
|
|
Stream: stream,
|
|
}
|
|
tokens, currTokens, req.Messages, err = a.rebuildMessages(currMessage)
|
|
if err != nil {
|
|
a.log.Error(err)
|
|
return
|
|
}
|
|
req.MaxTokens = a.openaiConf.MaxTokens - tokens
|
|
return
|
|
}
|
|
|
|
func (a *app) rebuildMessages(currMessage openai.ChatCompletionMessage) (tokens, currTokens int, messages []openai.ChatCompletionMessage, err error) {
|
|
messages = make([]openai.ChatCompletionMessage, 0, 2)
|
|
botTokens := 0
|
|
if a.openaiConf.BotDesc != "" {
|
|
sysMessage := openai.ChatCompletionMessage{
|
|
Role: openai.ChatMessageRoleSystem,
|
|
Content: a.openaiConf.BotDesc,
|
|
}
|
|
botTokens, err = tokenizer.GetTokens(&sysMessage, a.openaiConf.Model)
|
|
if err != nil {
|
|
return
|
|
}
|
|
messages = append(messages, sysMessage)
|
|
}
|
|
|
|
currTokens, err = tokenizer.GetTokens(&currMessage, a.openaiConf.Model)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if currTokens > a.openaiConf.MaxTokens-a.openaiConf.MinResponseTokens-botTokens-ChatPrimedTokens {
|
|
return 0, 0, nil, zerror.NewByMsg("请求消息超限")
|
|
}
|
|
|
|
tokens = currTokens + botTokens + ChatPrimedTokens
|
|
messages = append(messages, currMessage)
|
|
return
|
|
}
|
|
|
|
func (a *app) buildChatCompletionResponse(msg string) *proto.ChatCompletionResponse {
|
|
return &proto.ChatCompletionResponse{
|
|
Id: uuid.New().String(),
|
|
Object: "chat.completion",
|
|
Created: time.Now().Unix(),
|
|
Model: a.openaiConf.Model,
|
|
Choices: []*proto.ChatCompletionChoice{
|
|
{
|
|
Message: &proto.ChatCompletionMessage{
|
|
Role: openai.ChatMessageRoleAssistant,
|
|
Content: msg,
|
|
},
|
|
FinishReason: "stop",
|
|
},
|
|
},
|
|
Usage: &proto.Usage{},
|
|
}
|
|
}
|
|
|
|
func (a *app) buildChatCompletionStreamResponse(id, delta, finishReason string) *proto.ChatCompletionStreamResponse {
|
|
return &proto.ChatCompletionStreamResponse{
|
|
Id: id,
|
|
Object: "chat.completion.chunk",
|
|
Created: time.Now().Unix(),
|
|
Model: a.openaiConf.Model,
|
|
Choices: []*proto.ChatCompletionStreamChoice{
|
|
{
|
|
Index: 0,
|
|
Delta: &proto.ChatCompletionStreamChoiceDelta{
|
|
Content: delta,
|
|
Role: openai.ChatMessageRoleAssistant,
|
|
},
|
|
FinishReason: finishReason,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (a *app) buildChatCompletionStreamResponseList(id, msg string) []*proto.ChatCompletionStreamResponse {
|
|
list := make([]*proto.ChatCompletionStreamResponse, 0, len(msg))
|
|
for _, delta := range msg {
|
|
list = append(list, a.buildChatCompletionStreamResponse(id, string(delta), ""))
|
|
}
|
|
return list
|
|
}
|
|
|
|
func (a *app) keywords(in *proto.ChatCompletionRequest) []string {
|
|
pool := keywords_filter.GetKeywordsClientPool()
|
|
conn := pool.Get()
|
|
defer pool.Put(conn)
|
|
|
|
accessToken := config.GetConfig().DependOn.Keywords.AccessToken
|
|
client := keywords_proto.NewFilterClient(conn)
|
|
ctx := services.AppendBearerTokenToContext(context.Background(), accessToken)
|
|
req := &keywords_proto.FilterReq{Text: in.Message}
|
|
res, err := client.FindAll(ctx, req)
|
|
if err != nil {
|
|
a.log.Error(err)
|
|
return []string{}
|
|
}
|
|
return res.Keywords
|
|
}
|
|
|
|
func (a *app) sensitive(in *proto.ChatCompletionRequest) (ok bool, msg string, err error) {
|
|
pool := keywords_filter.GetSensitiveClientPool()
|
|
conn := pool.Get()
|
|
defer pool.Put(conn)
|
|
|
|
accessToken := config.GetConfig().DependOn.Sensitive.AccessToken
|
|
client := keywords_proto.NewFilterClient(conn)
|
|
ctx := services.AppendBearerTokenToContext(context.Background(), accessToken)
|
|
req := &keywords_proto.FilterReq{Text: in.Message}
|
|
res, err := client.Validate(ctx, req)
|
|
if err != nil {
|
|
a.log.Error(err)
|
|
return false, "", err
|
|
}
|
|
if !res.Ok {
|
|
return false, "触发到了知识盲区,请换个问题再问", nil
|
|
}
|
|
return true, "", nil
|
|
}
|