254 lines
6.0 KiB
Go
254 lines
6.0 KiB
Go
package config
|
|
|
|
import (
|
|
"bufio"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// Config is the application configuration tree. InitConfig decodes the
// configuration file into it with viper's Unmarshal and then passes it through
// normalizeConfig, which fills defaults and applies environment overrides.
//
// Fields that carry a `mapstructure` tag are decoded from the key named in the
// tag; the remaining fields rely on the decoder's field-name matching.
type Config struct {
	// Server groups the service's own IP, port and access token settings.
	Server struct {
		IP          string
		Port        int
		AccessToken string
	}
	// Log groups the logging level and log file path.
	Log struct {
		Level   string
		LogPath string `mapstructure:"logPath"`
	} `mapstructure:"log"`
	// Chat groups the chat-completion settings. BaseUrl, Model and ApiKey can
	// be overridden from the environment (see overrideChatFromEnv).
	Chat struct {
		ApiKey            string  `mapstructure:"api_key"`
		BaseUrl           string  `mapstructure:"base_url"`
		Model             string  `mapstructure:"model"`
		MaxTokens         int     `mapstructure:"max_tokens"`
		Temperature       float32 `mapstructure:"temperature"`
		TopP              float32 `mapstructure:"top_p"`
		PresencePenalty   float32 `mapstructure:"presence_penalty"`
		FrequencyPenalty  float32 `mapstructure:"frequency_penalty"`
		BotDesc           string  `mapstructure:"bot_desc"`
		MinResponseTokens int     `mapstructure:"min_response_tokens"`
		ContextTTL        int     `mapstructure:"context_ttl"`
		ContextLen        int     `mapstructure:"context_len"`
	}
	// Mysql groups the MySQL DSN and connection-pool limits.
	// NOTE(review): the unit of MaxLifeTime is not visible in this file.
	Mysql struct {
		DSN         string
		MaxLifeTime int
		MaxOpenConn int
		MaxIdleConn int
	}
	// Redis groups the Redis host, port and password.
	Redis struct {
		Host string
		Port int
		Pwd  string `mapstructure:"pwd"`
	}
	// DependOn groups the addresses (and, where present, access tokens) of
	// the external services named by its sub-sections.
	DependOn struct {
		Sensitive struct {
			Address     string
			AccessToken string
		}
		Keywords struct {
			Address     string
			AccessToken string
		}
		Tokenizer struct {
			Address string
		}
	}
	// Vector groups the vector-store settings.
	Vector struct {
		// Provider defaults to "tencent" and Threshold to 0.99 when left at
		// their zero values (see normalizeConfig).
		Provider  string
		Threshold float32
		// Tencent holds the Tencent vector database settings. Every field
		// left at its zero value is backfilled from the legacy VectorDB
		// section by normalizeConfig.
		Tencent struct {
			Url                string
			Username           string
			Pwd                string
			Database           string
			Timeout            int
			MaxIdleConnPerHost int
			ReadConsistency    string
			IdleConnTimeout    int
		}
		// Pgvector holds the pgvector settings: DSN, table, vector
		// dimensions and connection-pool limits.
		Pgvector struct {
			DSN         string `mapstructure:"dsn"`
			Table       string `mapstructure:"table"`
			Dimensions  int    `mapstructure:"dimensions"`
			MaxLifeTime int    `mapstructure:"maxLifeTime"`
			MaxOpenConn int    `mapstructure:"maxOpenConn"`
			MaxIdleConn int    `mapstructure:"maxIdleConn"`
		}
	}
	// Embedding groups the embedding-service settings. normalizeConfig
	// defaults Provider to "openai-compatible" and Timeout to 10, and falls
	// back to Chat.BaseUrl / Chat.ApiKey when BaseUrl / ApiKey are empty.
	// BaseUrl, Model and ApiKey can also be overridden from the environment
	// (see overrideEmbeddingFromEnv).
	Embedding struct {
		Provider string
		BaseUrl  string `mapstructure:"base_url"`
		ApiKey   string `mapstructure:"api_key"`
		Model    string `mapstructure:"model"`
		Timeout  int
	}
	// VectorDB is the legacy location of the Tencent vector database
	// settings; it is only read by normalizeConfig to backfill Vector.Tencent.
	VectorDB struct {
		Url                string
		Username           string
		Pwd                string
		Database           string
		Timeout            int
		MaxIdleConnPerHost int
		ReadConsistency    string
		IdleConnTimeout    int
	}
}
|
|
|
|
// conf is the package-level configuration assigned by InitConfig and returned
// by GetConfig. It has no initializer, so it is nil until InitConfig has run.
var conf *Config
|
|
|
|
func InitConfig(filePath string, typ ...string) {
|
|
loadProjectDotEnv(filePath)
|
|
v := viper.New()
|
|
v.SetConfigFile(filePath)
|
|
if len(typ) > 0 {
|
|
v.SetConfigType(typ[0])
|
|
}
|
|
err := v.ReadInConfig()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
conf = &Config{}
|
|
err = v.Unmarshal(conf)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
normalizeConfig(conf)
|
|
|
|
}
|
|
|
|
func GetConfig() *Config {
|
|
return conf
|
|
}
|
|
|
|
func normalizeConfig(conf *Config) {
|
|
if conf.Vector.Provider == "" {
|
|
conf.Vector.Provider = "tencent"
|
|
}
|
|
if conf.Vector.Threshold == 0 {
|
|
conf.Vector.Threshold = 0.99
|
|
}
|
|
|
|
// Backfill the new vector.tencent block from the legacy vectorDB config.
|
|
if conf.Vector.Tencent.Url == "" {
|
|
conf.Vector.Tencent.Url = conf.VectorDB.Url
|
|
}
|
|
if conf.Vector.Tencent.Username == "" {
|
|
conf.Vector.Tencent.Username = conf.VectorDB.Username
|
|
}
|
|
if conf.Vector.Tencent.Pwd == "" {
|
|
conf.Vector.Tencent.Pwd = conf.VectorDB.Pwd
|
|
}
|
|
if conf.Vector.Tencent.Database == "" {
|
|
conf.Vector.Tencent.Database = conf.VectorDB.Database
|
|
}
|
|
if conf.Vector.Tencent.Timeout == 0 {
|
|
conf.Vector.Tencent.Timeout = conf.VectorDB.Timeout
|
|
}
|
|
if conf.Vector.Tencent.MaxIdleConnPerHost == 0 {
|
|
conf.Vector.Tencent.MaxIdleConnPerHost = conf.VectorDB.MaxIdleConnPerHost
|
|
}
|
|
if conf.Vector.Tencent.ReadConsistency == "" {
|
|
conf.Vector.Tencent.ReadConsistency = conf.VectorDB.ReadConsistency
|
|
}
|
|
if conf.Vector.Tencent.IdleConnTimeout == 0 {
|
|
conf.Vector.Tencent.IdleConnTimeout = conf.VectorDB.IdleConnTimeout
|
|
}
|
|
|
|
if conf.Embedding.Provider == "" {
|
|
conf.Embedding.Provider = "openai-compatible"
|
|
}
|
|
if conf.Embedding.BaseUrl == "" {
|
|
conf.Embedding.BaseUrl = conf.Chat.BaseUrl
|
|
}
|
|
if conf.Embedding.ApiKey == "" {
|
|
conf.Embedding.ApiKey = conf.Chat.ApiKey
|
|
}
|
|
if conf.Embedding.Timeout == 0 {
|
|
conf.Embedding.Timeout = 10
|
|
}
|
|
overrideChatFromEnv(conf)
|
|
overrideEmbeddingFromEnv(conf)
|
|
}
|
|
|
|
func overrideChatFromEnv(conf *Config) {
|
|
if value := os.Getenv("AI_CHAT_OPENAI_BASE_URL"); value != "" {
|
|
conf.Chat.BaseUrl = value
|
|
} else if value := os.Getenv("OPENAI_BASE_URL"); value != "" {
|
|
conf.Chat.BaseUrl = value
|
|
}
|
|
|
|
if value := os.Getenv("AI_CHAT_OPENAI_MODEL"); value != "" {
|
|
conf.Chat.Model = value
|
|
} else if value := os.Getenv("OPENAI_MODEL"); value != "" {
|
|
conf.Chat.Model = value
|
|
}
|
|
|
|
if value := os.Getenv("AI_CHAT_OPENAI_API_KEY"); value != "" {
|
|
conf.Chat.ApiKey = value
|
|
return
|
|
}
|
|
if value := os.Getenv("OPENAI_API_KEY"); value != "" {
|
|
conf.Chat.ApiKey = value
|
|
return
|
|
}
|
|
if value := os.Getenv("MOONSHOT_API_KEY"); value != "" {
|
|
conf.Chat.ApiKey = value
|
|
}
|
|
}
|
|
|
|
func overrideEmbeddingFromEnv(conf *Config) {
|
|
if value := os.Getenv("AI_CHAT_EMBEDDING_BASE_URL"); value != "" {
|
|
conf.Embedding.BaseUrl = value
|
|
}
|
|
if value := os.Getenv("AI_CHAT_EMBEDDING_MODEL"); value != "" {
|
|
conf.Embedding.Model = value
|
|
}
|
|
if value := os.Getenv("AI_CHAT_EMBEDDING_API_KEY"); value != "" {
|
|
conf.Embedding.ApiKey = value
|
|
return
|
|
}
|
|
if value := os.Getenv("ZAI_API_KEY"); value != "" {
|
|
conf.Embedding.ApiKey = value
|
|
}
|
|
}
|
|
|
|
func loadProjectDotEnv(configFilePath string) {
|
|
projectRoot := filepath.Dir(filepath.Dir(configFilePath))
|
|
loadDotEnvFile(filepath.Join(projectRoot, ".env"))
|
|
}
|
|
|
|
// loadDotEnvFile reads KEY=VALUE lines from the file at path and exports them
// as process environment variables.
//
// Loading is best-effort and never fails the caller:
//   - A missing file is silently ignored; any other open error is logged.
//   - Blank lines, lines starting with '#', lines without '=' and lines with
//     an empty key are skipped (see parseDotEnvLine for the line syntax).
//   - Variables that already exist in the environment, even with an empty
//     value, are never overwritten.
//   - Failures of os.Setenv and read errors reported by the scanner (for
//     example a line longer than bufio.Scanner's token limit) are logged.
//     Only the key is logged, never the value, as values may be secrets.
func loadDotEnvFile(path string) {
	file, err := os.Open(path)
	if err != nil {
		// The .env file is optional, so "not found" is the normal case.
		if !os.IsNotExist(err) {
			log.Printf("config: skipping env file %q: %v", path, err)
		}
		return
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		key, value, ok := parseDotEnvLine(scanner.Text())
		if !ok {
			continue
		}
		// The real environment wins over the file.
		if _, exists := os.LookupEnv(key); exists {
			continue
		}
		if err := os.Setenv(key, value); err != nil {
			log.Printf("config: setting %q from env file %q: %v", key, path, err)
		}
	}
	if err := scanner.Err(); err != nil {
		log.Printf("config: reading env file %q: %v", path, err)
	}
}

// parseDotEnvLine splits one line of a .env file into key and value.
//
// ok is false for blank lines, comment lines (first non-space character is
// '#'), lines without '=' and lines whose key is empty. The line is split at
// the first '='; key and value are trimmed of surrounding whitespace. A
// leading "export" word before the key is dropped, and a value enclosed in
// one matching pair of single or double quotes has exactly that pair removed.
// Unbalanced or interior quotes are kept verbatim.
func parseDotEnvLine(line string) (key, value string, ok bool) {
	line = strings.TrimSpace(line)
	if line == "" || strings.HasPrefix(line, "#") {
		return "", "", false
	}

	rawKey, rawValue, found := strings.Cut(line, "=")
	if !found {
		return "", "", false
	}

	key = strings.TrimSpace(rawKey)
	// Accept shell-style "export KEY=value" lines.
	if fields := strings.Fields(key); len(fields) == 2 && fields[0] == "export" {
		key = fields[1]
	}
	if key == "" {
		return "", "", false
	}

	value = strings.TrimSpace(rawValue)
	if n := len(value); n >= 2 {
		first, last := value[0], value[n-1]
		if first == last && (first == '"' || first == '\'') {
			value = value[1 : n-1]
		}
	}
	return key, value, true
}
|