296 lines
7.9 KiB
Go
296 lines
7.9 KiB
Go
package controllers
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Arvintian/chatgpt-web/pkg/tokenizer"
|
|
"github.com/Arvintian/chatgpt-web/pkg/utils"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
ccache "github.com/karlseguin/ccache/v3"
|
|
openai "github.com/sashabaranov/go-openai"
|
|
"k8s.io/klog/v2"
|
|
)
|
|
|
|
// Tuning constants shared by the chat completion code in this file.
const (
	// ChatPrimedTokens is the fixed overhead buildMessage adds on top of the
	// prompt's own token count before any history is considered.
	// NOTE(review): presumably the tokens that prime the assistant's reply in
	// the chat format — confirm against the tokenizer package.
	ChatPrimedTokens = 2

	// KimiTopP is the top_p value topP returns for models whose name starts
	// with "kimi-" or "moonshot-".
	KimiTopP = 0.95
)
|
|
|
|
type ChatService struct {
|
|
client *openai.Client
|
|
store *ccache.Cache[ChatMessage]
|
|
params ChatCompletionParams
|
|
}
|
|
|
|
// ChatCompletionParams configures how ChatService builds completion requests
// and how long it remembers conversation messages.
type ChatCompletionParams struct {
	// Model is the model name sent upstream; it is also passed to the
	// tokenizer when counting tokens.
	Model string `json:"model"`

	// MaxTokens is the overall token budget: the request's completion limit
	// is MaxTokens minus the tokens already used by the messages sent.
	MaxTokens int `json:"max_tokens,omitempty"`

	// Temperature is forwarded unchanged to the completion request.
	Temperature float32 `json:"temperature,omitempty"`

	// PresencePenalty is forwarded unchanged to the completion request.
	PresencePenalty float32 `json:"presence_penalty,omitempty"`

	// FrequencyPenalty is forwarded unchanged to the completion request.
	FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`

	// ChatSessionTTL is the time-to-live given to every message written to
	// the message store.
	ChatSessionTTL time.Duration `json:"chat_session_ttl"`

	// ChatMinResponseTokens is the part of MaxTokens held back for the reply
	// when deciding how much prompt and history may be sent.
	ChatMinResponseTokens int `json:"chat_min_response_tokens"`
}
|
|
|
|
type ChatMessageRequest struct {
|
|
Prompt string `json:"prompt"`
|
|
Options ChatMessageRequestOptions `json:"options"`
|
|
}
|
|
|
|
// ChatMessageRequestOptions holds the optional, per-request settings nested
// under "options" in a ChatMessageRequest.
type ChatMessageRequestOptions struct {
	// Name is attached as the name of the user message sent upstream.
	Name string `json:"name"`

	// ParentMessageId is the ID of the message being replied to; an empty
	// value means there is no earlier history to include.
	ParentMessageId string `json:"parentMessageId"`
}
|
|
|
|
type ChatMessage struct {
|
|
ID string `json:"id"`
|
|
Text string `json:"text"`
|
|
Role string `json:"role"`
|
|
Name string `json:"name"`
|
|
Delta string `json:"delta"`
|
|
Detail openai.ChatCompletionStreamResponse `json:"detail"`
|
|
TokenCount int `json:"tokenCount"`
|
|
ParentMessageId string `json:"parentMessageId"`
|
|
}
|
|
|
|
func NewChatService(apiKey string, baseURL string, socksProxy string, params ChatCompletionParams) (*ChatService, error) {
|
|
config := openai.DefaultConfig(apiKey)
|
|
if baseURL != "" {
|
|
config.BaseURL = baseURL
|
|
}
|
|
klog.Infof("use openai base url: %s", config.BaseURL)
|
|
if socksProxy != "" {
|
|
proxyUrl, err := url.Parse(socksProxy) //socks5://user:password@127.0.0.1:1080
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
config.HTTPClient = &http.Client{
|
|
Transport: &http.Transport{
|
|
Proxy: http.ProxyURL(proxyUrl),
|
|
},
|
|
}
|
|
klog.Infof("use sock proxy: %s", proxyUrl)
|
|
}
|
|
chat := ChatService{
|
|
client: openai.NewClientWithConfig(config),
|
|
params: params,
|
|
store: ccache.New(ccache.Configure[ChatMessage]()),
|
|
}
|
|
return &chat, nil
|
|
}
|
|
|
|
func (chat *ChatService) ChatProcess(ctx *gin.Context) {
|
|
payload := ChatMessageRequest{}
|
|
if err := ctx.BindJSON(&payload); err != nil {
|
|
klog.Error(err)
|
|
ctx.JSON(200, gin.H{
|
|
"status": "Fail",
|
|
"message": fmt.Sprintf("%v", err),
|
|
"data": nil,
|
|
})
|
|
return
|
|
}
|
|
|
|
messageID := uuid.New().String()
|
|
|
|
message := ChatMessage{
|
|
ID: messageID,
|
|
Role: openai.ChatMessageRoleUser,
|
|
Text: payload.Prompt,
|
|
ParentMessageId: payload.Options.ParentMessageId,
|
|
}
|
|
|
|
result := ChatMessage{
|
|
ID: uuid.New().String(),
|
|
Role: openai.ChatMessageRoleAssistant,
|
|
Text: "",
|
|
ParentMessageId: messageID,
|
|
}
|
|
|
|
messages, numTokens, tokenCount, err := chat.buildMessage(payload)
|
|
if err != nil {
|
|
ctx.JSON(200, gin.H{
|
|
"status": "Fail",
|
|
"message": fmt.Sprintf("%v", err),
|
|
"data": nil,
|
|
})
|
|
return
|
|
}
|
|
|
|
message.TokenCount = tokenCount
|
|
chat.store.Set(messageID, message, chat.params.ChatSessionTTL)
|
|
|
|
//klog.Infof("send message %d tokens, set completion %d max tokens", numTokens, chat.params.MaxTokens-numTokens)
|
|
|
|
stream, err := chat.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
|
|
Model: chat.params.Model,
|
|
Messages: messages,
|
|
MaxTokens: chat.params.MaxTokens - numTokens,
|
|
Temperature: chat.params.Temperature,
|
|
PresencePenalty: chat.params.PresencePenalty,
|
|
FrequencyPenalty: chat.params.FrequencyPenalty,
|
|
TopP: chat.topP(),
|
|
Stream: true,
|
|
})
|
|
if err != nil {
|
|
klog.Error(err)
|
|
ctx.JSON(200, gin.H{
|
|
"status": "Fail",
|
|
"message": fmt.Sprintf("%v", err),
|
|
"data": nil,
|
|
})
|
|
return
|
|
}
|
|
defer stream.Close()
|
|
|
|
resp := stream.GetResponse()
|
|
if resp.StatusCode != 200 {
|
|
bts, _ := io.ReadAll(resp.Body)
|
|
ctx.JSON(200, gin.H{
|
|
"status": "Fail",
|
|
"message": fmt.Sprintf("%v", string(bts)),
|
|
"data": nil,
|
|
})
|
|
return
|
|
}
|
|
|
|
firstChunk := true
|
|
ctx.Header("Content-type", "application/octet-stream")
|
|
for {
|
|
rsp, err := stream.Recv()
|
|
if errors.Is(err, io.EOF) {
|
|
go func() {
|
|
tokenCount, err := tokenizer.GetTokenCount(openai.ChatCompletionMessage{
|
|
Role: result.Role,
|
|
Content: result.Text,
|
|
Name: result.Name,
|
|
}, chat.params.Model)
|
|
if err != nil {
|
|
klog.Error(err)
|
|
}
|
|
result.TokenCount = tokenCount
|
|
chat.store.Set(result.ID, result, chat.params.ChatSessionTTL)
|
|
}()
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
klog.Error(err)
|
|
ctx.JSON(200, gin.H{
|
|
"status": "Fail",
|
|
"message": fmt.Sprintf("OpenAI Event Error %v", err),
|
|
"data": nil,
|
|
})
|
|
return
|
|
}
|
|
|
|
if rsp.ID != "" {
|
|
result.ID = rsp.ID
|
|
}
|
|
|
|
if len(rsp.Choices) > 0 {
|
|
content := rsp.Choices[0].Delta.Content
|
|
result.Delta = content
|
|
if len(content) > 0 {
|
|
result.Text += content
|
|
}
|
|
result.Detail = rsp
|
|
}
|
|
|
|
bts, err := json.Marshal(result)
|
|
if err != nil {
|
|
klog.Error(err)
|
|
ctx.JSON(200, gin.H{
|
|
"status": "Fail",
|
|
"message": fmt.Sprintf("OpenAI Event Marshal Error %v", err),
|
|
"data": nil,
|
|
})
|
|
return
|
|
}
|
|
|
|
if !firstChunk {
|
|
ctx.Writer.Write([]byte("\n"))
|
|
} else {
|
|
firstChunk = false
|
|
}
|
|
|
|
if _, err := ctx.Writer.Write(bts); err != nil {
|
|
klog.Error(err)
|
|
return
|
|
}
|
|
|
|
ctx.Writer.Flush()
|
|
}
|
|
}
|
|
|
|
func (chat *ChatService) topP() float32 {
|
|
if strings.HasPrefix(chat.params.Model, "kimi-") || strings.HasPrefix(chat.params.Model, "moonshot-") {
|
|
return KimiTopP
|
|
}
|
|
return 1
|
|
}
|
|
|
|
func (chat *ChatService) buildMessage(payload ChatMessageRequest) ([]openai.ChatCompletionMessage, int, int, error) {
|
|
parentMessageId := payload.Options.ParentMessageId
|
|
messages := []openai.ChatCompletionMessage{}
|
|
tokenCount := 0
|
|
var err error
|
|
if len(payload.Prompt) > 0 {
|
|
chatMessage := openai.ChatCompletionMessage{
|
|
Role: openai.ChatMessageRoleUser,
|
|
Content: payload.Prompt,
|
|
Name: payload.Options.Name,
|
|
}
|
|
messages = append(messages, chatMessage)
|
|
tokenCount, err = tokenizer.GetTokenCount(chatMessage, chat.params.Model)
|
|
if err != nil {
|
|
return nil, 0, 0, err
|
|
}
|
|
if tokenCount >= (chat.params.MaxTokens - chat.params.ChatMinResponseTokens) {
|
|
return nil, 0, 0, fmt.Errorf("this model's maximum context length is %d tokens. you requested %d tokens in the messages", chat.params.MaxTokens, tokenCount)
|
|
}
|
|
}
|
|
numTokens := tokenCount + ChatPrimedTokens
|
|
for {
|
|
if parentMessageId == "" {
|
|
break
|
|
}
|
|
parentMessage, ok := chat.getMessageByID(parentMessageId)
|
|
if !ok {
|
|
break
|
|
}
|
|
parentCompletioMessage := openai.ChatCompletionMessage{
|
|
Role: parentMessage.Role,
|
|
Content: parentMessage.Text,
|
|
Name: parentMessage.Name,
|
|
}
|
|
if (numTokens + parentMessage.TokenCount) >= (chat.params.MaxTokens - chat.params.ChatMinResponseTokens) {
|
|
break
|
|
}
|
|
numTokens += parentMessage.TokenCount
|
|
messages = append(messages, parentCompletioMessage)
|
|
parentMessageId = parentMessage.ParentMessageId
|
|
}
|
|
utils.Reverse(messages)
|
|
return messages, numTokens, tokenCount, nil
|
|
}
|
|
|
|
func (chat *ChatService) getMessageByID(id string) (ChatMessage, bool) {
|
|
item := chat.store.Get(id)
|
|
if item == nil {
|
|
return ChatMessage{}, false
|
|
}
|
|
if item.Expired() {
|
|
return ChatMessage{}, false
|
|
}
|
|
return item.Value(), true
|
|
}
|