package server

import (
	"context"
	"time"
	"unicode/utf8"

	"ai-chat-service/pkg/config"
	"ai-chat-service/pkg/log"
	"ai-chat-service/pkg/zerror"
	"ai-chat-service/proto"
	"ai-chat-service/services"
	keywords_filter "ai-chat-service/services/keywords-filter"
	keywords_proto "ai-chat-service/services/keywords-filter/proto"
	"ai-chat-service/services/tokenizer"

	"github.com/google/uuid"
	"github.com/sashabaranov/go-openai"
)

// ChatPrimedTokens is a fixed token overhead that rebuildMessages adds on top
// of the counted message tokens for every request.
// NOTE(review): presumably accounts for the tokens used to prime the assistant
// reply — confirm against the tokenizer's counting rules.
const ChatPrimedTokens = 2

// openaiConf holds the effective OpenAI-compatible settings for one request:
// the service configuration, optionally overridden by request parameters.
type openaiConf struct {
	ApiKey           string
	BaseUrl          string
	Model            string
	MaxTokens        int // total token budget (prompt + completion) for a request
	Temperature      float32
	TopP             float32
	PresencePenalty  float32
	FrequencyPenalty float32
	BotDesc          string // optional system prompt; empty means no system message
	// MinResponseTokens is the number of tokens reserved for the reply when
	// checking whether the prompt fits into MaxTokens.
	MinResponseTokens int
}

// app bundles the per-request configuration with the service logger.
type app struct {
	openaiConf *openaiConf
	log        log.ILogger
}

// newApp builds the per-request app. It starts from the service's chat
// configuration and applies any values set in in.ChatParam on top of it.
// A zero/empty request value means "not set" and keeps the configured default.
func (s *chatService) newApp(in *proto.ChatCompletionRequest) *app {
	conf := &openaiConf{
		ApiKey:            s.config.Chat.ApiKey,
		BaseUrl:           s.config.Chat.BaseUrl,
		Model:             s.config.Chat.Model,
		MaxTokens:         s.config.Chat.MaxTokens,
		Temperature:       s.config.Chat.Temperature,
		TopP:              s.config.Chat.TopP,
		PresencePenalty:   s.config.Chat.PresencePenalty,
		FrequencyPenalty:  s.config.Chat.FrequencyPenalty,
		BotDesc:           s.config.Chat.BotDesc,
		MinResponseTokens: s.config.Chat.MinResponseTokens,
	}

	if p := in.ChatParam; p != nil {
		if p.Model != "" {
			conf.Model = p.Model
		}
		// Sampling parameters follow the same "override only when set" rule as
		// the other fields. go-openai tags them `omitempty`, so a zero value
		// would be dropped from the upstream request anyway. Overwriting with
		// zero therefore replaced the configured value with the upstream API
		// default instead of the service's own configuration.
		if p.TopP != 0 {
			conf.TopP = p.TopP
		}
		if p.FrequencyPenalty != 0 {
			conf.FrequencyPenalty = p.FrequencyPenalty
		}
		if p.PresencePenalty != 0 {
			conf.PresencePenalty = p.PresencePenalty
		}
		if p.Temperature != 0 {
			conf.Temperature = p.Temperature
		}
		if p.BotDesc != "" {
			conf.BotDesc = p.BotDesc
		}
		if p.MaxTokens != 0 {
			conf.MaxTokens = int(p.MaxTokens)
		}
		if p.MinResponseTokens != 0 {
			conf.MinResponseTokens = int(p.MinResponseTokens)
		}
	}

	return &app{
		openaiConf: conf,
		log:        s.log,
	}
}

// getOpenaiClient creates a go-openai client for the configured API key and
// base URL.
func (a *app) getOpenaiClient() *openai.Client {
	conf := openai.DefaultConfig(a.openaiConf.ApiKey)
	conf.BaseURL = a.openaiConf.BaseUrl
	return openai.NewClientWithConfig(conf)
}

// buildChatCompletionRequest builds the upstream chat completion request for
// the user's message.
//
// It returns the request, the total prompt token count (tokens), the token
// count of the user's message alone (currTokens), the user message itself,
// and any error from message building. Errors are also logged here.
// On success req.MaxTokens is set to the remaining budget
// (MaxTokens minus the prompt tokens).
func (a *app) buildChatCompletionRequest(in *proto.ChatCompletionRequest, stream bool) (req openai.ChatCompletionRequest, tokens, currTokens int, currMessage openai.ChatCompletionMessage, err error) {
	currMessage = openai.ChatCompletionMessage{
		Role:    openai.ChatMessageRoleUser,
		Content: in.Message,
	}
	req = openai.ChatCompletionRequest{
		Model:            a.openaiConf.Model,
		MaxTokens:        a.openaiConf.MinResponseTokens,
		Temperature:      a.openaiConf.Temperature,
		TopP:             a.openaiConf.TopP,
		PresencePenalty:  a.openaiConf.PresencePenalty,
		FrequencyPenalty: a.openaiConf.FrequencyPenalty,
		Stream:           stream,
	}
	tokens, currTokens, req.Messages, err = a.rebuildMessages(currMessage)
	if err != nil {
		a.log.Error(err)
		return
	}
	// Give the completion whatever budget the prompt leaves over.
	req.MaxTokens = a.openaiConf.MaxTokens - tokens
	return
}

// rebuildMessages assembles the message list: an optional system message
// (BotDesc) followed by currMessage.
//
// It returns the total prompt tokens (message tokens + BotDesc tokens +
// ChatPrimedTokens), the tokens of currMessage alone, and the messages.
// It fails if tokenization fails, or if currMessage does not fit into
// MaxTokens after reserving MinResponseTokens, the system prompt and
// ChatPrimedTokens. On any error all other results are zero values.
func (a *app) rebuildMessages(currMessage openai.ChatCompletionMessage) (tokens, currTokens int, messages []openai.ChatCompletionMessage, err error) {
	messages = make([]openai.ChatCompletionMessage, 0, 2)
	botTokens := 0
	if a.openaiConf.BotDesc != "" {
		sysMessage := openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleSystem,
			Content: a.openaiConf.BotDesc,
		}
		botTokens, err = tokenizer.GetTokens(&sysMessage, a.openaiConf.Model)
		if err != nil {
			return 0, 0, nil, err
		}
		messages = append(messages, sysMessage)
	}

	currTokens, err = tokenizer.GetTokens(&currMessage, a.openaiConf.Model)
	if err != nil {
		return 0, 0, nil, err
	}

	budget := a.openaiConf.MaxTokens - a.openaiConf.MinResponseTokens - botTokens - ChatPrimedTokens
	if currTokens > budget {
		return 0, 0, nil, zerror.NewByMsg("请求消息超限")
	}

	tokens = currTokens + botTokens + ChatPrimedTokens
	messages = append(messages, currMessage)
	return tokens, currTokens, messages, nil
}

// buildChatCompletionResponse wraps msg in a non-streaming completion response
// with a fresh UUID, the current timestamp, finish reason "stop" and an empty
// usage record.
func (a *app) buildChatCompletionResponse(msg string) *proto.ChatCompletionResponse {
	return &proto.ChatCompletionResponse{
		Id:      uuid.New().String(),
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   a.openaiConf.Model,
		Choices: []*proto.ChatCompletionChoice{
			{
				Message: &proto.ChatCompletionMessage{
					Role:    openai.ChatMessageRoleAssistant,
					Content: msg,
				},
				FinishReason: "stop",
			},
		},
		Usage: &proto.Usage{},
	}
}

// countMessageTokens returns the token count of a single message with the
// given role and content, using the configured model's tokenizer.
func (a *app) countMessageTokens(role, content string) (int, error) {
	message := openai.ChatCompletionMessage{
		Role:    role,
		Content: content,
	}
	return tokenizer.GetTokens(&message, a.openaiConf.Model)
}

// buildZeroUsage returns an empty usage record.
func (a *app) buildZeroUsage() *proto.Usage {
	return &proto.Usage{}
}

// buildUsage computes the usage record for a completed exchange.
// Completion tokens are counted from answer as an assistant message and are
// zero when answer is empty.
func (a *app) buildUsage(promptTokens int, answer string) (*proto.Usage, error) {
	completionTokens := 0
	if answer != "" {
		tokens, err := a.countMessageTokens(openai.ChatMessageRoleAssistant, answer)
		if err != nil {
			return nil, err
		}
		completionTokens = tokens
	}
	return &proto.Usage{
		PromptTokens:     int32(promptTokens),
		CompletionTokens: int32(completionTokens),
		TotalTokens:      int32(promptTokens + completionTokens),
	}, nil
}

// buildChatCompletionStreamResponse builds one streaming chunk carrying delta
// as assistant content at choice index 0.
func (a *app) buildChatCompletionStreamResponse(id, delta, finishReason string) *proto.ChatCompletionStreamResponse {
	return &proto.ChatCompletionStreamResponse{
		Id:      id,
		Object:  "chat.completion.chunk",
		Created: time.Now().Unix(),
		Model:   a.openaiConf.Model,
		Choices: []*proto.ChatCompletionStreamChoice{
			{
				Index: 0,
				Delta: &proto.ChatCompletionStreamChoiceDelta{
					Content: delta,
					Role:    openai.ChatMessageRoleAssistant,
				},
				FinishReason: finishReason,
			},
		},
	}
}

// buildChatCompletionStreamResponseList splits msg into one streaming chunk
// per rune (not per byte), each with an empty finish reason.
func (a *app) buildChatCompletionStreamResponseList(id, msg string) []*proto.ChatCompletionStreamResponse {
	// One chunk is produced per rune, so size by rune count; len(msg) counts
	// bytes and over-allocates for multi-byte (e.g. CJK) text.
	list := make([]*proto.ChatCompletionStreamResponse, 0, utf8.RuneCountInString(msg))
	for _, r := range msg {
		list = append(list, a.buildChatCompletionStreamResponse(id, string(r), ""))
	}
	return list
}

// keywords asks the keywords-filter service for all keywords found in the
// request message. It is best-effort: on RPC failure the error is logged and
// an empty slice is returned.
func (a *app) keywords(in *proto.ChatCompletionRequest) []string {
	pool := keywords_filter.GetKeywordsClientPool()
	conn := pool.Get()
	defer pool.Put(conn)

	accessToken := config.GetConfig().DependOn.Keywords.AccessToken
	client := keywords_proto.NewFilterClient(conn)
	// NOTE(review): no deadline is set on this context, so a stalled filter
	// service blocks the caller indefinitely — consider a timeout.
	ctx := services.AppendBearerTokenToContext(context.Background(), accessToken)
	req := &keywords_proto.FilterReq{Text: in.Message}
	res, err := client.FindAll(ctx, req)
	if err != nil {
		a.log.Error(err)
		return []string{}
	}
	return res.Keywords
}

// sensitive validates the request message against the sensitive-words filter
// service.
//
// It returns ok=true when the message passes. When the filter rejects it, it
// returns ok=false with a user-facing msg and a nil error. On RPC failure it
// logs and returns the error.
func (a *app) sensitive(in *proto.ChatCompletionRequest) (ok bool, msg string, err error) {
	pool := keywords_filter.GetSensitiveClientPool()
	conn := pool.Get()
	defer pool.Put(conn)

	accessToken := config.GetConfig().DependOn.Sensitive.AccessToken
	client := keywords_proto.NewFilterClient(conn)
	// NOTE(review): no deadline is set on this context — consider a timeout.
	ctx := services.AppendBearerTokenToContext(context.Background(), accessToken)
	req := &keywords_proto.FilterReq{Text: in.Message}
	res, err := client.Validate(ctx, req)
	if err != nil {
		a.log.Error(err)
		return false, "", err
	}
	if !res.Ok {
		return false, "触发到了知识盲区,请换个问题再问", nil
	}
	return true, "", nil
}