first commit
This commit is contained in:
295
chatgpt-web-backend/pkg/controllers/chat.go
Normal file
295
chatgpt-web-backend/pkg/controllers/chat.go
Normal file
@@ -0,0 +1,295 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Arvintian/chatgpt-web/pkg/tokenizer"
|
||||
"github.com/Arvintian/chatgpt-web/pkg/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
ccache "github.com/karlseguin/ccache/v3"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
"k8s.io/klog/v2"
|
||||
)
|
||||
|
||||
// Token-accounting and sampling constants used by the chat handlers.
const (
	// ChatPrimedTokens is added once to the prompt's token count when
	// buildMessage computes the total tokens of an outgoing request.
	ChatPrimedTokens = 2

	// KimiTopP is the top_p value topP returns for models whose name starts
	// with "kimi-" or "moonshot-".
	KimiTopP = 0.95
)
|
||||
|
||||
type ChatService struct {
|
||||
client *openai.Client
|
||||
store *ccache.Cache[ChatMessage]
|
||||
params ChatCompletionParams
|
||||
}
|
||||
|
||||
// ChatCompletionParams is the per-service configuration applied to every
// completion request issued by ChatService.
type ChatCompletionParams struct {
	// Model is sent as the request model and is also passed to the tokenizer.
	Model string `json:"model"`
	// MaxTokens is the overall token budget; the completion request is sent
	// with MaxTokens minus the tokens already used by the outgoing messages.
	MaxTokens int `json:"max_tokens,omitempty"`
	// Temperature is forwarded unchanged to the completion request.
	Temperature float32 `json:"temperature,omitempty"`
	// PresencePenalty is forwarded unchanged to the completion request.
	PresencePenalty float32 `json:"presence_penalty,omitempty"`
	// FrequencyPenalty is forwarded unchanged to the completion request.
	FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
	// ChatSessionTTL is the lifetime given to each message stored in the
	// service's message cache.
	ChatSessionTTL time.Duration `json:"chat_session_ttl"`
	// ChatMinResponseTokens is the part of MaxTokens kept free for the reply:
	// prompt plus history must stay below MaxTokens - ChatMinResponseTokens.
	ChatMinResponseTokens int `json:"chat_min_response_tokens"`
}
|
||||
|
||||
type ChatMessageRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Options ChatMessageRequestOptions `json:"options"`
|
||||
}
|
||||
|
||||
// ChatMessageRequestOptions holds the optional settings of a chat request.
type ChatMessageRequestOptions struct {
	// Name is copied into the Name field of the outgoing user message.
	Name string `json:"name"`
	// ParentMessageId is the ID of the cached message this prompt follows; an
	// empty value starts a conversation without history.
	ParentMessageId string `json:"parentMessageId"`
}
|
||||
|
||||
type ChatMessage struct {
|
||||
ID string `json:"id"`
|
||||
Text string `json:"text"`
|
||||
Role string `json:"role"`
|
||||
Name string `json:"name"`
|
||||
Delta string `json:"delta"`
|
||||
Detail openai.ChatCompletionStreamResponse `json:"detail"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
ParentMessageId string `json:"parentMessageId"`
|
||||
}
|
||||
|
||||
func NewChatService(apiKey string, baseURL string, socksProxy string, params ChatCompletionParams) (*ChatService, error) {
|
||||
config := openai.DefaultConfig(apiKey)
|
||||
if baseURL != "" {
|
||||
config.BaseURL = baseURL
|
||||
}
|
||||
klog.Infof("use openai base url: %s", config.BaseURL)
|
||||
if socksProxy != "" {
|
||||
proxyUrl, err := url.Parse(socksProxy) //socks5://user:password@127.0.0.1:1080
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.HTTPClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyUrl),
|
||||
},
|
||||
}
|
||||
klog.Infof("use sock proxy: %s", proxyUrl)
|
||||
}
|
||||
chat := ChatService{
|
||||
client: openai.NewClientWithConfig(config),
|
||||
params: params,
|
||||
store: ccache.New(ccache.Configure[ChatMessage]()),
|
||||
}
|
||||
return &chat, nil
|
||||
}
|
||||
|
||||
func (chat *ChatService) ChatProcess(ctx *gin.Context) {
|
||||
payload := ChatMessageRequest{}
|
||||
if err := ctx.BindJSON(&payload); err != nil {
|
||||
klog.Error(err)
|
||||
ctx.JSON(200, gin.H{
|
||||
"status": "Fail",
|
||||
"message": fmt.Sprintf("%v", err),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
messageID := uuid.New().String()
|
||||
|
||||
message := ChatMessage{
|
||||
ID: messageID,
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Text: payload.Prompt,
|
||||
ParentMessageId: payload.Options.ParentMessageId,
|
||||
}
|
||||
|
||||
result := ChatMessage{
|
||||
ID: uuid.New().String(),
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Text: "",
|
||||
ParentMessageId: messageID,
|
||||
}
|
||||
|
||||
messages, numTokens, tokenCount, err := chat.buildMessage(payload)
|
||||
if err != nil {
|
||||
ctx.JSON(200, gin.H{
|
||||
"status": "Fail",
|
||||
"message": fmt.Sprintf("%v", err),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
message.TokenCount = tokenCount
|
||||
chat.store.Set(messageID, message, chat.params.ChatSessionTTL)
|
||||
|
||||
//klog.Infof("send message %d tokens, set completion %d max tokens", numTokens, chat.params.MaxTokens-numTokens)
|
||||
|
||||
stream, err := chat.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
|
||||
Model: chat.params.Model,
|
||||
Messages: messages,
|
||||
MaxTokens: chat.params.MaxTokens - numTokens,
|
||||
Temperature: chat.params.Temperature,
|
||||
PresencePenalty: chat.params.PresencePenalty,
|
||||
FrequencyPenalty: chat.params.FrequencyPenalty,
|
||||
TopP: chat.topP(),
|
||||
Stream: true,
|
||||
})
|
||||
if err != nil {
|
||||
klog.Error(err)
|
||||
ctx.JSON(200, gin.H{
|
||||
"status": "Fail",
|
||||
"message": fmt.Sprintf("%v", err),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
resp := stream.GetResponse()
|
||||
if resp.StatusCode != 200 {
|
||||
bts, _ := io.ReadAll(resp.Body)
|
||||
ctx.JSON(200, gin.H{
|
||||
"status": "Fail",
|
||||
"message": fmt.Sprintf("%v", string(bts)),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
firstChunk := true
|
||||
ctx.Header("Content-type", "application/octet-stream")
|
||||
for {
|
||||
rsp, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
go func() {
|
||||
tokenCount, err := tokenizer.GetTokenCount(openai.ChatCompletionMessage{
|
||||
Role: result.Role,
|
||||
Content: result.Text,
|
||||
Name: result.Name,
|
||||
}, chat.params.Model)
|
||||
if err != nil {
|
||||
klog.Error(err)
|
||||
}
|
||||
result.TokenCount = tokenCount
|
||||
chat.store.Set(result.ID, result, chat.params.ChatSessionTTL)
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
klog.Error(err)
|
||||
ctx.JSON(200, gin.H{
|
||||
"status": "Fail",
|
||||
"message": fmt.Sprintf("OpenAI Event Error %v", err),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if rsp.ID != "" {
|
||||
result.ID = rsp.ID
|
||||
}
|
||||
|
||||
if len(rsp.Choices) > 0 {
|
||||
content := rsp.Choices[0].Delta.Content
|
||||
result.Delta = content
|
||||
if len(content) > 0 {
|
||||
result.Text += content
|
||||
}
|
||||
result.Detail = rsp
|
||||
}
|
||||
|
||||
bts, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
klog.Error(err)
|
||||
ctx.JSON(200, gin.H{
|
||||
"status": "Fail",
|
||||
"message": fmt.Sprintf("OpenAI Event Marshal Error %v", err),
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !firstChunk {
|
||||
ctx.Writer.Write([]byte("\n"))
|
||||
} else {
|
||||
firstChunk = false
|
||||
}
|
||||
|
||||
if _, err := ctx.Writer.Write(bts); err != nil {
|
||||
klog.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (chat *ChatService) topP() float32 {
|
||||
if strings.HasPrefix(chat.params.Model, "kimi-") || strings.HasPrefix(chat.params.Model, "moonshot-") {
|
||||
return KimiTopP
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (chat *ChatService) buildMessage(payload ChatMessageRequest) ([]openai.ChatCompletionMessage, int, int, error) {
|
||||
parentMessageId := payload.Options.ParentMessageId
|
||||
messages := []openai.ChatCompletionMessage{}
|
||||
tokenCount := 0
|
||||
var err error
|
||||
if len(payload.Prompt) > 0 {
|
||||
chatMessage := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: payload.Prompt,
|
||||
Name: payload.Options.Name,
|
||||
}
|
||||
messages = append(messages, chatMessage)
|
||||
tokenCount, err = tokenizer.GetTokenCount(chatMessage, chat.params.Model)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
if tokenCount >= (chat.params.MaxTokens - chat.params.ChatMinResponseTokens) {
|
||||
return nil, 0, 0, fmt.Errorf("this model's maximum context length is %d tokens. you requested %d tokens in the messages", chat.params.MaxTokens, tokenCount)
|
||||
}
|
||||
}
|
||||
numTokens := tokenCount + ChatPrimedTokens
|
||||
for {
|
||||
if parentMessageId == "" {
|
||||
break
|
||||
}
|
||||
parentMessage, ok := chat.getMessageByID(parentMessageId)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
parentCompletioMessage := openai.ChatCompletionMessage{
|
||||
Role: parentMessage.Role,
|
||||
Content: parentMessage.Text,
|
||||
Name: parentMessage.Name,
|
||||
}
|
||||
if (numTokens + parentMessage.TokenCount) >= (chat.params.MaxTokens - chat.params.ChatMinResponseTokens) {
|
||||
break
|
||||
}
|
||||
numTokens += parentMessage.TokenCount
|
||||
messages = append(messages, parentCompletioMessage)
|
||||
parentMessageId = parentMessage.ParentMessageId
|
||||
}
|
||||
utils.Reverse(messages)
|
||||
return messages, numTokens, tokenCount, nil
|
||||
}
|
||||
|
||||
func (chat *ChatService) getMessageByID(id string) (ChatMessage, bool) {
|
||||
item := chat.store.Get(id)
|
||||
if item == nil {
|
||||
return ChatMessage{}, false
|
||||
}
|
||||
if item.Expired() {
|
||||
return ChatMessage{}, false
|
||||
}
|
||||
return item.Value(), true
|
||||
}
|
||||
23
chatgpt-web-backend/pkg/middlewares/rate_limit.go
Normal file
23
chatgpt-web-backend/pkg/middlewares/rate_limit.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func RateLimitMiddleware(r rate.Limit, b int) gin.HandlerFunc {
|
||||
limiter := rate.NewLimiter(r, b)
|
||||
return func(c *gin.Context) {
|
||||
if !limiter.Allow() {
|
||||
// 请求被限制,返回错误信息
|
||||
c.JSON(429, gin.H{
|
||||
"status": "Fail",
|
||||
"message": "Too many requests, please try again later",
|
||||
"data": nil,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
43
chatgpt-web-backend/pkg/tokenizer/tokenizer.go
Normal file
43
chatgpt-web-backend/pkg/tokenizer/tokenizer.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// tokenInfo is the JSON reply of the tokenizer service.
type tokenInfo struct {
	// Code is the service's result code; GetTokenCount treats 200 as success.
	Code int `json:"code"`
	// Count is the number of tokens reported for the submitted message.
	Count int `json:"num_tokens"`
	// Msg is the text GetTokenCount returns as the error when Code is not 200.
	Msg string `json:"msg"`
}
|
||||
|
||||
func GetTokenCount(message openai.ChatCompletionMessage, model string) (int, error) {
|
||||
url := fmt.Sprintf("http://127.0.0.1:5000/tokenizer/%s", model)
|
||||
info := tokenInfo{}
|
||||
if err := postJSON(url, &message, &info); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if info.Code != 200 {
|
||||
return 0, fmt.Errorf("%v", info.Msg)
|
||||
}
|
||||
return info.Count, nil
|
||||
}
|
||||
|
||||
func postJSON(url string, requestData *openai.ChatCompletionMessage, responseData *tokenInfo) error {
|
||||
requestBody, err := json.Marshal(requestData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := http.Post(url, "application/json", bytes.NewBuffer(requestBody))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return json.NewDecoder(resp.Body).Decode(responseData)
|
||||
}
|
||||
65
chatgpt-web-backend/pkg/utils/utils.go
Normal file
65
chatgpt-web-backend/pkg/utils/utils.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"k8s.io/klog/v2"
|
||||
)
|
||||
|
||||
func ReplaceInFile(filePath string, targetStr string, replaceStr string) error {
|
||||
// 读取文件内容
|
||||
content, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 判断是否需要替换
|
||||
if !strings.Contains(string(content), targetStr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 替换字符串并写回文件,保持原有 filemode
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newContent := strings.ReplaceAll(string(content), targetStr, replaceStr)
|
||||
err = ioutil.WriteFile(filePath, []byte(newContent), info.Mode())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
klog.Infof("Replaced in file: %s\n", filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReplaceFiles(rootDir string, replacePairs map[string]string) error {
|
||||
err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 如果当前路径是目录,则继续遍历
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
// 处理文件
|
||||
for targetStr, replaceStr := range replacePairs {
|
||||
err = ReplaceInFile(path, targetStr, replaceStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reverse reverses the elements of s in place. Nil, empty and single-element
// slices are left unchanged.
func Reverse[S ~[]E, E any](s S) {
	last := len(s) - 1
	for i := 0; i < len(s)/2; i++ {
		s[i], s[last-i] = s[last-i], s[i]
	}
}
|
||||
Reference in New Issue
Block a user