tokenizer
This commit is contained in:
20
ai-chat-service/chat-server/chat-context/chat_context.go
Normal file
20
ai-chat-service/chat-server/chat-context/chat_context.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package chat_context
|
||||
|
||||
import "github.com/sashabaranov/go-openai"
|
||||
|
||||
type ChatMessage struct {
|
||||
//当前记录ID
|
||||
ID string `json:"id,omitempty"`
|
||||
//上一条记录ID
|
||||
PID string `json:"pid,omitempty"`
|
||||
//消息内容
|
||||
Message openai.ChatCompletionMessage `json:"message"`
|
||||
//该消息tokens数
|
||||
Tokens int `json:"tokens,omitempty"`
|
||||
}
|
||||
|
||||
type ContextCache interface {
|
||||
Get(key string) (*ChatMessage, error)
|
||||
Set(key string, value *ChatMessage, ttl int) error
|
||||
Close()
|
||||
}
|
||||
50
ai-chat-service/chat-server/chat-context/redis.go
Normal file
50
ai-chat-service/chat-server/chat-context/redis.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package chat_context
|
||||
|
||||
import (
|
||||
predis "ai-chat-service/pkg/db/redis"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"time"
|
||||
)
|
||||
|
||||
type redisCache struct {
|
||||
redisClient *redis.Client
|
||||
}
|
||||
|
||||
func NewRedisCache() ContextCache {
|
||||
pool := predis.GetPool()
|
||||
return &redisCache{
|
||||
redisClient: pool.Get(),
|
||||
}
|
||||
}
|
||||
func getRedisKey(key string) string {
|
||||
return predis.GetKey(key)
|
||||
}
|
||||
|
||||
func (c *redisCache) Get(key string) (*ChatMessage, error) {
|
||||
key = getRedisKey(key)
|
||||
str, err := c.redisClient.Get(context.Background(), key).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value := &ChatMessage{}
|
||||
err = json.Unmarshal([]byte(str), value)
|
||||
return value, err
|
||||
}
|
||||
func (c *redisCache) Set(key string, value *ChatMessage, ttl int) error {
|
||||
key = getRedisKey(key)
|
||||
bytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
str := string(bytes)
|
||||
return c.redisClient.SetEx(context.Background(), key, str, time.Duration(ttl)*time.Second).Err()
|
||||
}
|
||||
func (c *redisCache) Close() {
|
||||
pool := predis.GetPool()
|
||||
pool.Put(c.redisClient)
|
||||
}
|
||||
54
ai-chat-service/chat-server/data/chat_records.go
Normal file
54
ai-chat-service/chat-server/data/chat_records.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type IChatRecordsData interface {
|
||||
Add(record *ChatRecord) error
|
||||
GetById(id int64) (record *ChatRecord, err error)
|
||||
}
|
||||
|
||||
// ChatRecord is one row of the chat_records table: a user message, the AI
// reply, and their token counts.
type ChatRecord struct {
	ID              int64    `json:"id"`
	UserMsg         string   `json:"user_msg"`
	UserMsgTokens   int      `json:"user_msg_tokens"`
	UserMsgKeywords []string `json:"user_msg_keywords"`
	AIMsg           string   `json:"ai_msg"`
	AIMsgTokens     int      `json:"ai_msg_tokens"`
	ReqTokens       int      `json:"req_tokens"`
	CreateAt        int64    `json:"create_at"`
}
|
||||
|
||||
type chatRecordsData struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewChatRecordsData(db *sql.DB) IChatRecordsData {
|
||||
return &chatRecordsData{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (data *chatRecordsData) Add(cr *ChatRecord) (err error) {
|
||||
sqlStr := "insert into chat_records(user_msg,user_msg_tokens,user_msg_keywords,ai_msg,ai_msg_tokens,req_tokens,create_at)values(?,?,?,?,?,?,?)"
|
||||
res, err := data.db.Exec(sqlStr, cr.UserMsg, cr.UserMsgTokens, strings.Join(cr.UserMsgKeywords, ","), cr.AIMsg, cr.AIMsgTokens, cr.ReqTokens, cr.CreateAt)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cr.ID, _ = res.LastInsertId()
|
||||
return
|
||||
}
|
||||
func (data *chatRecordsData) GetById(id int64) (cr *ChatRecord, err error) {
|
||||
sqlStr := "select id,user_msg,user_msg_tokens,user_msg_keywords,ai_msg,ai_msg_tokens,req_tokens,create_at from chat_records where id = ?"
|
||||
row := data.db.QueryRow(sqlStr, id)
|
||||
cr = &ChatRecord{}
|
||||
var keywords string
|
||||
err = row.Scan(&cr.ID, &cr.UserMsg, &cr.UserMsgTokens, &keywords, &cr.AIMsg, &cr.AIMsgTokens, &cr.ReqTokens, &cr.CreateAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cr.UserMsgKeywords = strings.Split(keywords, ",")
|
||||
return cr, err
|
||||
}
|
||||
78
ai-chat-service/chat-server/main.go
Normal file
78
ai-chat-service/chat-server/main.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"ai-chat-service/chat-server/data"
|
||||
metrics_app "ai-chat-service/chat-server/metrics-app"
|
||||
metrics_bus "ai-chat-service/chat-server/metrics-bus"
|
||||
"ai-chat-service/chat-server/server"
|
||||
vector_data "ai-chat-service/chat-server/vector-data"
|
||||
"ai-chat-service/interceptor"
|
||||
"ai-chat-service/pkg/config"
|
||||
"ai-chat-service/pkg/db/mysql"
|
||||
"ai-chat-service/pkg/db/redis"
|
||||
"ai-chat-service/pkg/db/vector"
|
||||
"ai-chat-service/pkg/log"
|
||||
"ai-chat-service/proto"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/health"
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
"net/http"
|
||||
|
||||
"net"
|
||||
)
|
||||
|
||||
// configFile holds the value of the -config flag, the configuration file
// handed to config.InitConfig; it defaults to dev.config.yaml.
var configFile = flag.String("config", "dev.config.yaml", "")
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
registry := prometheus.NewRegistry()
|
||||
registry.MustRegister(collectors.NewGoCollector(), collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
|
||||
busMetrics := metrics_bus.NewBusMetrics(registry)
|
||||
|
||||
http.Handle("/metrics", promhttp.HandlerFor(registry, promhttp.HandlerOpts{}))
|
||||
go http.ListenAndServe(":8080", nil)
|
||||
|
||||
//初始化配置文件
|
||||
config.InitConfig(*configFile)
|
||||
cnf := config.GetConfig()
|
||||
//初始化日志
|
||||
log.SetLevel(cnf.Log.Level)
|
||||
log.SetOutput(log.GetRotateWriter(cnf.Log.LogPath))
|
||||
log.SetPrintCaller(true)
|
||||
|
||||
logger := log.NewLogger()
|
||||
logger.SetLevel(cnf.Log.Level)
|
||||
logger.SetOutput(log.GetRotateWriter(cnf.Log.LogPath))
|
||||
logger.SetPrintCaller(true)
|
||||
|
||||
// 初始化Mysql
|
||||
mysql.InitMysql(cnf)
|
||||
// 初始化redis
|
||||
redis.InitRedisPool(cnf)
|
||||
// 初始化向量数据库
|
||||
vector.InitDB(cnf)
|
||||
|
||||
recordsData := data.NewChatRecordsData(mysql.GetDB())
|
||||
|
||||
lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", cnf.Server.IP, cnf.Server.Port))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer(grpc.UnaryInterceptor(interceptor.UnaryAuthInterceptor), grpc.StreamInterceptor(metrics_app.NewStreamMiddleware(registry).WrapHandler()))
|
||||
service := server.NewChatService(recordsData, vector_data.NewChatRecordsData(cnf, vector.GetVdb()), cnf, logger, busMetrics)
|
||||
proto.RegisterChatServer(s, service)
|
||||
|
||||
healthCheckSrv := health.NewServer()
|
||||
grpc_health_v1.RegisterHealthServer(s, healthCheckSrv)
|
||||
|
||||
if err = s.Serve(lis); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
82
ai-chat-service/chat-server/metrics-app/metrics_app.go
Normal file
82
ai-chat-service/chat-server/metrics-app/metrics_app.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package metrics_app
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"google.golang.org/grpc"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
type StreamMiddleware interface {
|
||||
WrapHandler() grpc.StreamServerInterceptor
|
||||
}
|
||||
type streamMiddleware struct {
|
||||
registry *prometheus.Registry
|
||||
handlerCounter *prometheus.CounterVec
|
||||
handlerDuration *prometheus.SummaryVec
|
||||
handlerAtHour *prometheus.HistogramVec
|
||||
}
|
||||
|
||||
// Namespace and subsystem used for the metrics created in this package.
const (
	NAMESPACE = "ai_chat"
	SUBSYSTEM = "chat_service"
)
|
||||
|
||||
func NewStreamMiddleware(registry *prometheus.Registry) StreamMiddleware {
|
||||
counter := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: NAMESPACE,
|
||||
Subsystem: SUBSYSTEM,
|
||||
Name: "requests_total",
|
||||
ConstLabels: map[string]string{"app": "ai_chat"},
|
||||
Help: "用于累计请求次数",
|
||||
}, []string{"full_method"})
|
||||
gauge := prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
Namespace: NAMESPACE,
|
||||
Subsystem: SUBSYSTEM,
|
||||
Name: "curr_num_goroutine",
|
||||
ConstLabels: map[string]string{"app": "ai_chat"},
|
||||
Help: "当前存在的goroutine数量",
|
||||
}, func() float64 {
|
||||
return float64(runtime.NumGoroutine())
|
||||
})
|
||||
histogram := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: NAMESPACE,
|
||||
Subsystem: SUBSYSTEM,
|
||||
Name: "request_hour",
|
||||
ConstLabels: map[string]string{"app": "ai_chat"},
|
||||
Help: "http请求发生在一天之中的哪个小时",
|
||||
Buckets: []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
|
||||
}, []string{"full_method"})
|
||||
summary := prometheus.NewSummaryVec(prometheus.SummaryOpts{
|
||||
Namespace: NAMESPACE,
|
||||
Subsystem: SUBSYSTEM,
|
||||
Name: "request_duration_ms",
|
||||
ConstLabels: map[string]string{"app": "ai_chat"},
|
||||
Help: "请求时长分布",
|
||||
Objectives: map[float64]float64{0.1: 0.01, 0.5: 0.01, 0.9: 0.01, 0.99: 0.01},
|
||||
}, []string{"full_method"})
|
||||
registry.MustRegister(counter, gauge, histogram, summary)
|
||||
return &streamMiddleware{
|
||||
registry: registry,
|
||||
handlerCounter: counter,
|
||||
handlerDuration: summary,
|
||||
handlerAtHour: histogram,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *streamMiddleware) WrapHandler() grpc.StreamServerInterceptor {
|
||||
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
label := map[string]string{
|
||||
"full_method": info.FullMethod,
|
||||
}
|
||||
s.handlerCounter.With(label).Inc()
|
||||
hour := time.Now().Hour()
|
||||
s.handlerAtHour.With(label).Observe(float64(hour))
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
s.handlerDuration.With(label).Observe(float64(time.Since(start).Milliseconds()))
|
||||
}()
|
||||
err := handler(srv, ss)
|
||||
return err
|
||||
}
|
||||
}
|
||||
53
ai-chat-service/chat-server/metrics-bus/metrics_bus.go
Normal file
53
ai-chat-service/chat-server/metrics-bus/metrics_bus.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package metrics_bus
|
||||
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
type BusMetrics struct {
|
||||
QuestionsTotalCounter prometheus.Counter
|
||||
KeywordsQuestionsTotalCounter prometheus.Counter
|
||||
SensitiveQuestionsTotalCounter prometheus.Counter
|
||||
ErrQuestionsTotalCounter prometheus.Counter
|
||||
}
|
||||
|
||||
// Namespace and subsystem used for the business metrics of this package.
const (
	NAMESPACE = "ai_chat"
	SUBSYSTEM = "chat_service"
)
|
||||
|
||||
func NewBusMetrics(registry *prometheus.Registry) *BusMetrics {
|
||||
questionsTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: NAMESPACE,
|
||||
Subsystem: SUBSYSTEM,
|
||||
Name: "questions_total",
|
||||
ConstLabels: map[string]string{"app": "ai_chat"},
|
||||
Help: "记录用户提交问题的总数,仅包含记录到DB的问题数量",
|
||||
})
|
||||
keywordsQuestionsTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: NAMESPACE,
|
||||
Subsystem: SUBSYSTEM,
|
||||
Name: "keywords_questions_total",
|
||||
ConstLabels: map[string]string{"app": "ai_chat"},
|
||||
Help: "记录用户提交的包含关键词的问题总数",
|
||||
})
|
||||
sensitiveQuestionsTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: NAMESPACE,
|
||||
Subsystem: SUBSYSTEM,
|
||||
Name: "sensitive_questions_total",
|
||||
ConstLabels: map[string]string{"app": "ai_chat"},
|
||||
Help: "记录用户提交的触发敏感词的问题总数",
|
||||
})
|
||||
errQuestionsTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: NAMESPACE,
|
||||
Subsystem: SUBSYSTEM,
|
||||
Name: "err_questions_total",
|
||||
ConstLabels: map[string]string{"app": "ai_chat"},
|
||||
Help: "记录用户提交问题时报错的总数",
|
||||
})
|
||||
registry.MustRegister(questionsTotalCounter, keywordsQuestionsTotalCounter, sensitiveQuestionsTotalCounter, errQuestionsTotalCounter)
|
||||
return &BusMetrics{
|
||||
QuestionsTotalCounter: questionsTotalCounter,
|
||||
KeywordsQuestionsTotalCounter: keywordsQuestionsTotalCounter,
|
||||
SensitiveQuestionsTotalCounter: sensitiveQuestionsTotalCounter,
|
||||
ErrQuestionsTotalCounter: errQuestionsTotalCounter,
|
||||
}
|
||||
}
|
||||
293
ai-chat-service/chat-server/server/app.go
Normal file
293
ai-chat-service/chat-server/server/app.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
chat_context "ai-chat-service/chat-server/chat-context"
|
||||
"ai-chat-service/pkg/config"
|
||||
"ai-chat-service/pkg/log"
|
||||
"ai-chat-service/pkg/zerror"
|
||||
"ai-chat-service/proto"
|
||||
"ai-chat-service/services"
|
||||
keywords_filter "ai-chat-service/services/keywords-filter"
|
||||
keywords_proto "ai-chat-service/services/keywords-filter/proto"
|
||||
"ai-chat-service/services/tokenizer"
|
||||
"context"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChatPrimedTokens is the fixed number of tokens added to the running total
// for each message when the prompt size is computed.
const ChatPrimedTokens = 2
|
||||
|
||||
// openaiConf is the effective set of chat settings for one request: the
// service defaults with any per-request overrides applied.
type openaiConf struct {
	ApiKey  string
	BaseUrl string
	Model   string

	// MaxTokens is the overall token budget; MinResponseTokens is the part
	// of it kept free for the reply.
	MaxTokens         int
	MinResponseTokens int

	Temperature      float32
	TopP             float32
	PresencePenalty  float32
	FrequencyPenalty float32

	// BotDesc, when non-empty, is sent as the system message.
	BotDesc string

	// ContextTTL is the TTL used when saving context entries; ContextLen is
	// the maximum number of context entries loaded.
	ContextTTL int
	ContextLen int
}
|
||||
type app struct {
|
||||
openaiConf *openaiConf
|
||||
log log.ILogger
|
||||
// TODO 内容上下文对象
|
||||
contextCache chat_context.ContextCache
|
||||
}
|
||||
|
||||
func (s *chatService) newApp(in *proto.ChatCompletionRequest, contextCache chat_context.ContextCache) *app {
|
||||
conf := &openaiConf{
|
||||
ApiKey: s.config.Chat.ApiKey,
|
||||
BaseUrl: s.config.Chat.BaseUrl,
|
||||
Model: s.config.Chat.Model,
|
||||
MaxTokens: s.config.Chat.MaxTokens,
|
||||
Temperature: s.config.Chat.Temperature,
|
||||
TopP: s.config.Chat.TopP,
|
||||
PresencePenalty: s.config.Chat.PresencePenalty,
|
||||
FrequencyPenalty: s.config.Chat.FrequencyPenalty,
|
||||
BotDesc: s.config.Chat.BotDesc,
|
||||
ContextTTL: s.config.Chat.ContextTTL,
|
||||
ContextLen: s.config.Chat.ContextLen,
|
||||
MinResponseTokens: s.config.Chat.MinResponseTokens,
|
||||
}
|
||||
if in.ChatParam != nil {
|
||||
if in.ChatParam.Model != "" {
|
||||
conf.Model = in.ChatParam.Model
|
||||
}
|
||||
if in.ChatParam.TopP != 0 {
|
||||
conf.TopP = in.ChatParam.TopP
|
||||
}
|
||||
if in.ChatParam.FrequencyPenalty != 0 {
|
||||
conf.FrequencyPenalty = in.ChatParam.FrequencyPenalty
|
||||
}
|
||||
if in.ChatParam.PresencePenalty != 0 {
|
||||
conf.PresencePenalty = in.ChatParam.PresencePenalty
|
||||
}
|
||||
if in.ChatParam.Temperature != 0 {
|
||||
conf.Temperature = in.ChatParam.Temperature
|
||||
}
|
||||
if in.ChatParam.BotDesc != "" {
|
||||
conf.BotDesc = in.ChatParam.BotDesc
|
||||
}
|
||||
if in.ChatParam.MaxTokens != 0 {
|
||||
conf.MaxTokens = int(in.ChatParam.MaxTokens)
|
||||
}
|
||||
if in.ChatParam.ContextTTL != 0 {
|
||||
conf.ContextTTL = int(in.ChatParam.ContextTTL)
|
||||
}
|
||||
if in.ChatParam.ContextLen != 0 {
|
||||
conf.ContextLen = int(in.ChatParam.ContextLen)
|
||||
}
|
||||
if in.ChatParam.MinResponseTokens != 0 {
|
||||
conf.MinResponseTokens = int(in.ChatParam.MinResponseTokens)
|
||||
}
|
||||
}
|
||||
return &app{
|
||||
openaiConf: conf,
|
||||
log: s.log,
|
||||
contextCache: contextCache,
|
||||
}
|
||||
}
|
||||
func (a *app) getOpenaiClient() *openai.Client {
|
||||
accessToken := a.openaiConf.ApiKey
|
||||
config := openai.DefaultConfig(accessToken)
|
||||
config.BaseURL = a.openaiConf.BaseUrl
|
||||
client := openai.NewClientWithConfig(config)
|
||||
return client
|
||||
}
|
||||
func (a *app) buildChatCompletionRequest(in *proto.ChatCompletionRequest, stream bool) (req openai.ChatCompletionRequest, tokens, currTokens int, currMessage openai.ChatCompletionMessage, err error) {
|
||||
//当前消息
|
||||
currMessage = openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: in.Message,
|
||||
}
|
||||
req = openai.ChatCompletionRequest{
|
||||
Model: a.openaiConf.Model,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
currMessage,
|
||||
},
|
||||
MaxTokens: a.openaiConf.MinResponseTokens,
|
||||
Temperature: a.openaiConf.Temperature,
|
||||
TopP: a.openaiConf.TopP,
|
||||
PresencePenalty: a.openaiConf.PresencePenalty,
|
||||
FrequencyPenalty: a.openaiConf.FrequencyPenalty,
|
||||
Stream: stream,
|
||||
}
|
||||
contextList := make([]*chat_context.ChatMessage, 0)
|
||||
if in.EnableContext {
|
||||
//从缓存中获取上下文信息
|
||||
contextList = a.getContext(in.Pid)
|
||||
}
|
||||
//重构req.Messages
|
||||
tokens, currTokens, req.Messages, err = a.rebuildMessages(contextList, currMessage)
|
||||
if err != nil {
|
||||
a.log.Error(err)
|
||||
return
|
||||
}
|
||||
req.MaxTokens = a.openaiConf.MaxTokens - tokens
|
||||
return
|
||||
}
|
||||
func (a *app) rebuildMessages(contextList []*chat_context.ChatMessage, currMessage openai.ChatCompletionMessage) (tokens, currTokens int, messages []openai.ChatCompletionMessage, err error) {
|
||||
var sysMessage openai.ChatCompletionMessage
|
||||
botTokens := 0
|
||||
if a.openaiConf.BotDesc != "" {
|
||||
sysMessage = openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: a.openaiConf.BotDesc,
|
||||
}
|
||||
botTokens, err = tokenizer.GetTokens(&sysMessage, a.openaiConf.Model)
|
||||
if err != nil {
|
||||
a.log.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
messages = []openai.ChatCompletionMessage{currMessage}
|
||||
currTokens, err = tokenizer.GetTokens(&currMessage, a.openaiConf.Model)
|
||||
if err != nil {
|
||||
a.log.Error(err)
|
||||
return
|
||||
}
|
||||
if currTokens > a.openaiConf.MaxTokens-a.openaiConf.MinResponseTokens-botTokens-ChatPrimedTokens {
|
||||
err = zerror.NewByMsg("请求消息超限")
|
||||
a.log.Error(err)
|
||||
return
|
||||
}
|
||||
tokens = currTokens + botTokens + ChatPrimedTokens
|
||||
if contextList != nil {
|
||||
for _, item := range contextList {
|
||||
if tokens+item.Tokens+ChatPrimedTokens > a.openaiConf.MaxTokens-a.openaiConf.MinResponseTokens {
|
||||
break
|
||||
}
|
||||
messages = append(messages, item.Message)
|
||||
tokens += item.Tokens + ChatPrimedTokens
|
||||
}
|
||||
}
|
||||
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
|
||||
messages[i], messages[j] = messages[j], messages[i]
|
||||
}
|
||||
if botTokens > 0 {
|
||||
messages = append([]openai.ChatCompletionMessage{sysMessage}, messages...)
|
||||
}
|
||||
return
|
||||
}
|
||||
func (a *app) buildChatCompletionResponse(msg string) *proto.ChatCompletionResponse {
|
||||
res := &proto.ChatCompletionResponse{
|
||||
Id: uuid.New().String(),
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: a.openaiConf.Model,
|
||||
Choices: []*proto.ChatCompletionChoice{
|
||||
{
|
||||
Message: &proto.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: msg,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Usage: &proto.Usage{
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: 0,
|
||||
},
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (a *app) buildChatCompletionStreamResponse(id, delta, finishReason string) *proto.ChatCompletionStreamResponse {
|
||||
res := &proto.ChatCompletionStreamResponse{
|
||||
Id: id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: time.Now().Unix(),
|
||||
Model: a.openaiConf.Model,
|
||||
Choices: []*proto.ChatCompletionStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: &proto.ChatCompletionStreamChoiceDelta{
|
||||
Content: delta,
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (a *app) buildChatCompletionStreamResponseList(id, msg string) []*proto.ChatCompletionStreamResponse {
|
||||
list := make([]*proto.ChatCompletionStreamResponse, 0)
|
||||
for _, delta := range msg {
|
||||
list = append(list, a.buildChatCompletionStreamResponse(id, string(delta), ""))
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
func (a *app) getContext(id string) []*chat_context.ChatMessage {
|
||||
maxLen := a.openaiConf.ContextLen
|
||||
list := make([]*chat_context.ChatMessage, 0, maxLen)
|
||||
key := id
|
||||
for i := 0; i < maxLen; i++ {
|
||||
value, err := a.contextCache.Get(key)
|
||||
if err != nil {
|
||||
a.log.Error(err)
|
||||
return nil
|
||||
}
|
||||
if value == nil {
|
||||
break
|
||||
}
|
||||
list = append(list, value)
|
||||
key = value.PID
|
||||
}
|
||||
return list
|
||||
}
|
||||
func (a *app) saveContext(value *chat_context.ChatMessage) error {
|
||||
err := a.contextCache.Set(value.ID, value, a.openaiConf.ContextTTL)
|
||||
if err != nil {
|
||||
a.log.Error(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (a *app) keywords(in *proto.ChatCompletionRequest) []string {
|
||||
pool := keywords_filter.GetKeywordsClientPool()
|
||||
conn := pool.Get()
|
||||
defer pool.Put(conn)
|
||||
accessToken := config.GetConfig().DependOn.Keywords.AccessToken
|
||||
client := keywords_proto.NewFilterClient(conn)
|
||||
ctx := services.AppendBearerTokenToContext(context.Background(), accessToken)
|
||||
req := &keywords_proto.FilterReq{
|
||||
Text: in.Message,
|
||||
}
|
||||
res, err := client.FindAll(ctx, req)
|
||||
if err != nil {
|
||||
a.log.Error(err)
|
||||
return []string{}
|
||||
}
|
||||
return res.Keywords
|
||||
|
||||
}
|
||||
func (a *app) sensitive(in *proto.ChatCompletionRequest) (ok bool, msg string, err error) {
|
||||
pool := keywords_filter.GetSensitiveClientPool()
|
||||
conn := pool.Get()
|
||||
defer pool.Put(conn)
|
||||
accessToken := config.GetConfig().DependOn.Sensitive.AccessToken
|
||||
client := keywords_proto.NewFilterClient(conn)
|
||||
ctx := services.AppendBearerTokenToContext(context.Background(), accessToken)
|
||||
req := &keywords_proto.FilterReq{
|
||||
Text: in.Message,
|
||||
}
|
||||
res, err := client.Validate(ctx, req)
|
||||
if err != nil {
|
||||
a.log.Error(err)
|
||||
return false, "", err
|
||||
}
|
||||
ok = res.Ok
|
||||
if !ok {
|
||||
msg = "触发到了知识盲区,请换个问题再问"
|
||||
}
|
||||
return
|
||||
}
|
||||
352
ai-chat-service/chat-server/server/server.go
Normal file
352
ai-chat-service/chat-server/server/server.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
chat_context "ai-chat-service/chat-server/chat-context"
|
||||
"ai-chat-service/chat-server/data"
|
||||
metrics_bus "ai-chat-service/chat-server/metrics-bus"
|
||||
vector_data "ai-chat-service/chat-server/vector-data"
|
||||
"ai-chat-service/pkg/config"
|
||||
"ai-chat-service/pkg/log"
|
||||
"ai-chat-service/proto"
|
||||
"ai-chat-service/services/tokenizer"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/golang/protobuf/jsonpb"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type chatService struct {
|
||||
proto.UnimplementedChatServer
|
||||
config *config.Config
|
||||
log log.ILogger
|
||||
data data.IChatRecordsData
|
||||
vectorData vector_data.IChatRecordsData
|
||||
busMetrics *metrics_bus.BusMetrics
|
||||
}
|
||||
|
||||
func NewChatService(data data.IChatRecordsData, vectorData vector_data.IChatRecordsData, config *config.Config, log log.ILogger, busMetrics *metrics_bus.BusMetrics) proto.ChatServer {
|
||||
return &chatService{
|
||||
config: config,
|
||||
log: log,
|
||||
data: data,
|
||||
vectorData: vectorData,
|
||||
busMetrics: busMetrics,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatService) ChatCompletion(ctx context.Context, in *proto.ChatCompletionRequest) (*proto.ChatCompletionResponse, error) {
|
||||
redisContextCache := chat_context.NewRedisCache()
|
||||
defer redisContextCache.Close()
|
||||
|
||||
app := s.newApp(in, redisContextCache)
|
||||
//敏感词过滤
|
||||
ok, msg, err := app.sensitive(in)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
res := app.buildChatCompletionResponse(msg)
|
||||
return res, nil
|
||||
}
|
||||
|
||||
//关键词提取
|
||||
keywords := app.keywords(in)
|
||||
if len(keywords) > 0 {
|
||||
idStr, score, err := s.vectorData.QueryData(context.Background(), map[string][]string{"keywords": {strings.Join(keywords, ",")}})
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else if score > 0.99 {
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else {
|
||||
record, err := s.data.GetById(id)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else {
|
||||
res := app.buildChatCompletionResponse(record.AIMsg)
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := app.getOpenaiClient()
|
||||
req, tokens, currTokens, currMessage, err := app.buildChatCompletionRequest(in, false)
|
||||
resp, err := client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
res := &proto.ChatCompletionResponse{}
|
||||
bytes, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
err = jsonpb.UnmarshalString(string(bytes), res)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
reqContext := &chat_context.ChatMessage{
|
||||
ID: in.Id,
|
||||
PID: in.Pid,
|
||||
Message: currMessage,
|
||||
Tokens: currTokens,
|
||||
}
|
||||
err := app.saveContext(reqContext)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
resContext := &chat_context.ChatMessage{
|
||||
ID: resp.ID,
|
||||
PID: reqContext.ID,
|
||||
Message: resp.Choices[0].Message,
|
||||
Tokens: resp.Usage.CompletionTokens,
|
||||
}
|
||||
err = app.saveContext(resContext)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
records := &data.ChatRecord{
|
||||
UserMsg: in.Message,
|
||||
UserMsgTokens: currTokens,
|
||||
UserMsgKeywords: keywords,
|
||||
AIMsg: resp.Choices[0].Message.Content,
|
||||
AIMsgTokens: resp.Usage.CompletionTokens,
|
||||
ReqTokens: tokens,
|
||||
CreateAt: time.Now().Unix(),
|
||||
}
|
||||
err := s.data.Add(records)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
//保存到向量数据库
|
||||
if len(keywords) > 0 {
|
||||
list := []*vector_data.ChatRecord{
|
||||
{
|
||||
ID: strconv.FormatInt(records.ID, 10),
|
||||
KVs: map[string]string{
|
||||
"keywords": strings.Join(keywords, ","),
|
||||
},
|
||||
},
|
||||
}
|
||||
err = s.vectorData.UpsertData(context.Background(), list)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return res, err
|
||||
}
|
||||
func (s *chatService) ChatCompletionStream(in *proto.ChatCompletionRequest, stream proto.Chat_ChatCompletionStreamServer) error {
|
||||
redisContextCache := chat_context.NewRedisCache()
|
||||
defer redisContextCache.Close()
|
||||
|
||||
app := s.newApp(in, redisContextCache)
|
||||
//敏感词过滤
|
||||
ok, msg, err := app.sensitive(in)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
s.busMetrics.SensitiveQuestionsTotalCounter.Inc()
|
||||
resId := uuid.New().String()
|
||||
startRes := app.buildChatCompletionStreamResponse(resId, "", "")
|
||||
endRes := app.buildChatCompletionStreamResponse(resId, "", "stop")
|
||||
err = stream.Send(startRes)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
resList := app.buildChatCompletionStreamResponseList(resId, msg)
|
||||
for _, res := range resList {
|
||||
err = stream.Send(res)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = stream.Send(endRes)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//关键词提取
|
||||
keywords := app.keywords(in)
|
||||
|
||||
if len(keywords) > 0 {
|
||||
s.busMetrics.KeywordsQuestionsTotalCounter.Inc()
|
||||
idStr, score, err := s.vectorData.QueryData(context.Background(), map[string][]string{"keywords": {strings.Join(keywords, ",")}})
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else if score > 0.99 {
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else {
|
||||
record, err := s.data.GetById(id)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
} else {
|
||||
resId := uuid.New().String()
|
||||
startRes := app.buildChatCompletionStreamResponse(resId, "", "")
|
||||
endRes := app.buildChatCompletionStreamResponse(resId, "", "stop")
|
||||
err = stream.Send(startRes)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
resList := app.buildChatCompletionStreamResponseList(resId, record.AIMsg)
|
||||
for _, res := range resList {
|
||||
err = stream.Send(res)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = stream.Send(endRes)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := app.getOpenaiClient()
|
||||
req, tokens, currTokens, currMessage, err := app.buildChatCompletionRequest(in, false)
|
||||
chatStream, err := client.CreateChatCompletionStream(stream.Context(), req)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
defer chatStream.Close()
|
||||
completionContent := ""
|
||||
resultID := ""
|
||||
for {
|
||||
resp, err := chatStream.Recv()
|
||||
if err != nil && err != io.EOF {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if resultID == "" {
|
||||
resultID = resp.ID
|
||||
}
|
||||
completionContent += resp.Choices[0].Delta.Content
|
||||
res := &proto.ChatCompletionStreamResponse{}
|
||||
bytes, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
err = jsonpb.UnmarshalString(string(bytes), res)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
err = stream.Send(res)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
resultMessage := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: completionContent,
|
||||
}
|
||||
model := s.config.Chat.Model
|
||||
if in.ChatParam != nil && in.ChatParam.Model != "" {
|
||||
model = in.ChatParam.Model
|
||||
}
|
||||
resultTokens, err := tokenizer.GetTokens(&resultMessage, model)
|
||||
if err != nil {
|
||||
s.busMetrics.ErrQuestionsTotalCounter.Inc()
|
||||
s.log.Error(err)
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
reqContext := &chat_context.ChatMessage{
|
||||
ID: in.Id,
|
||||
PID: in.Pid,
|
||||
Message: currMessage,
|
||||
Tokens: currTokens,
|
||||
}
|
||||
err := app.saveContext(reqContext)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
resContext := &chat_context.ChatMessage{
|
||||
ID: resultID,
|
||||
PID: reqContext.ID,
|
||||
Message: resultMessage,
|
||||
Tokens: resultTokens,
|
||||
}
|
||||
err = app.saveContext(resContext)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
s.busMetrics.QuestionsTotalCounter.Inc()
|
||||
records := &data.ChatRecord{
|
||||
UserMsg: in.Message,
|
||||
UserMsgTokens: currTokens,
|
||||
UserMsgKeywords: keywords,
|
||||
AIMsg: completionContent,
|
||||
AIMsgTokens: resultTokens,
|
||||
ReqTokens: tokens,
|
||||
CreateAt: time.Now().Unix(),
|
||||
}
|
||||
err := s.data.Add(records)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
//保存到向量数据库
|
||||
if len(keywords) > 0 {
|
||||
list := []*vector_data.ChatRecord{
|
||||
{
|
||||
ID: strconv.FormatInt(records.ID, 10),
|
||||
KVs: map[string]string{
|
||||
"keywords": strings.Join(keywords, ","),
|
||||
},
|
||||
},
|
||||
}
|
||||
err = s.vectorData.UpsertData(context.Background(), list)
|
||||
if err != nil {
|
||||
s.log.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
69
ai-chat-service/chat-server/vector-data/chat_records.go
Normal file
69
ai-chat-service/chat-server/vector-data/chat_records.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package vector_data
|
||||
|
||||
import (
|
||||
"ai-chat-service/pkg/config"
|
||||
"context"
|
||||
"github.com/tencent/vectordatabase-sdk-go/tcvectordb"
|
||||
)
|
||||
|
||||
// CHAT_RECORDS is the name of the vector database collection that holds the
// chat records.
const CHAT_RECORDS = "chat_records"
|
||||
|
||||
// ChatRecord is one document for the vector database: an ID plus the named
// text fields stored with it.
type ChatRecord struct {
	// ID is the document ID.
	ID string
	// KVs maps field names to field values.
	KVs map[string]string
}
|
||||
type IChatRecordsData interface {
|
||||
UpsertData(ctx context.Context, list []*ChatRecord) error
|
||||
QueryData(ctx context.Context, text map[string][]string) (id string, score float32, err error)
|
||||
}
|
||||
|
||||
type chatRecordsData struct {
|
||||
config *config.Config
|
||||
vectorDB *tcvectordb.Client
|
||||
}
|
||||
|
||||
func NewChatRecordsData(config *config.Config, vectorDB *tcvectordb.Client) IChatRecordsData {
|
||||
return &chatRecordsData{
|
||||
config: config,
|
||||
vectorDB: vectorDB,
|
||||
}
|
||||
}
|
||||
func (data *chatRecordsData) UpsertData(ctx context.Context, list []*ChatRecord) error {
|
||||
database := data.config.VectorDB.Database
|
||||
collection := CHAT_RECORDS
|
||||
coll := data.vectorDB.Database(database).Collection(collection)
|
||||
documentList := make([]tcvectordb.Document, 0, len(list))
|
||||
for _, l := range list {
|
||||
doc := tcvectordb.Document{
|
||||
Id: l.ID,
|
||||
}
|
||||
doc.Fields = make(map[string]tcvectordb.Field, len(l.KVs))
|
||||
for k, v := range l.KVs {
|
||||
doc.Fields[k] = tcvectordb.Field{Val: v}
|
||||
}
|
||||
documentList = append(documentList, doc)
|
||||
}
|
||||
_, err := coll.Upsert(ctx, documentList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (data *chatRecordsData) QueryData(ctx context.Context, text map[string][]string) (id string, score float32, err error) {
|
||||
database := data.config.VectorDB.Database
|
||||
collection := CHAT_RECORDS
|
||||
coll := data.vectorDB.Database(database).Collection(collection)
|
||||
result, err := coll.SearchByText(ctx, text, &tcvectordb.SearchDocumentParams{
|
||||
Params: &tcvectordb.SearchDocParams{Ef: 100},
|
||||
Limit: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if len(result.Documents) > 0 && len(result.Documents[0]) > 0 {
|
||||
doc := result.Documents[0][0]
|
||||
return doc.Id, doc.Score, nil
|
||||
|
||||
}
|
||||
return "", 0, nil
|
||||
}
|
||||
Reference in New Issue
Block a user