diff --git a/.gitignore b/.gitignore index adf8f72..7502cf6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,23 +1,20 @@ -# ---> Go -# If you prefer the allow list template instead of the deny list, see community template: -# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore -# -# Binaries for programs and plugins -*.exe -*.exe~ -*.dll -*.so -*.dylib - -# Test binary, built with `go test -c` -*.test - -# Output of the go coverage tool, specifically when used with LiteIDE -*.out - -# Dependency directories (remove the comment below to include it) -# vendor/ - -# Go workspace file -go.work +# Logs +logs +*.log +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? +tmp +bin +data +config.toml +static/upload +storage.json diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e72ae93 --- /dev/null +++ b/Makefile @@ -0,0 +1,15 @@ +SHELL=/usr/bin/env bash +NAME := geekai +all: amd64 arm64 + +amd64: + CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/$(NAME)-linux main.go +.PHONY: amd64 + +arm64: + CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build -o bin/$(NAME)-linux main.go +.PHONY: arm64 + +clean: + rm -rf bin/$(NAME)-* +.PHONY: clean diff --git a/README.md b/README.md index 17eed04..98fde2c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,5 @@ -# ai-go +# chatgpt-AI-go + + 后端 API Go 语言实现。技术选型采用 Gin + Mysql 架构,依赖注入使用的是 fx 框架,ORM 采用的是 GORM 框架。 + diff --git a/config.sample.toml b/config.sample.toml new file mode 100644 index 0000000..2b9642d --- /dev/null +++ b/config.sample.toml @@ -0,0 +1,125 @@ +Listen = "0.0.0.0:5678" +ProxyURL = "" # 如 http://127.0.0.1:7777 +MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local" +StaticDir = "./static" # 静态资源的目录 +StaticUrl = "/static" # 静态资源访问 URL +AesEncryptKey = "" +WeChatBot = false + +[Session] + SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换 + MaxAge = 86400 + +[Redis] # redis 配置信息 + Host = "localhost" + Port = 6379 + Password = "" + DB = 0 + +[ApiConfig] # 微博热搜,今日头条等函数服务 API 配置,此为第三方插件服务,如需使用请联系作者开通 + ApiURL = "" + AppId = "" + Token = "" + + +[SMS] # Sms 配置,用于发送短信 + Active = "Ali" # 当前启用的短信服务,默认使用阿里云 + [SMS.Bao] + Username = "" + Password = "" + Domain = "api.smsbao.com" + Sign = "【极客学长】" + CodeTemplate = "您的验证码是{code}。5分钟有效,若非本人操作,请忽略本短信。" + [SMS.Ali] + AccessKey = "" + AccessSecret = "" + Product = "Dysmsapi" + Domain = "dysmsapi.aliyuncs.com" + Sign = "" + CodeTempId = "" + +[OSS] # OSS 配置,用于存储 MJ 绘画图片 + Active = "local" # 默认使用本地文件存储引擎 + [OSS.Local] + BasePath = "./static/upload" # 本地文件上传根路径 + BaseURL = "http://localhost:5678/static/upload" # 本地上传文件前缀 URL,线上需要把 localhost 替换成自己的实际域名或者IP + [OSS.Minio] + Endpoint = "" # 如 172.22.11.200:9000 + AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key + AccessSecret = "" + Bucket = "chatgpt-plus" # 替换为你自己创建的 Bucket,注意要给 Bucket 设置公开的读权限,否则会出现图片无法显示。 + UseSSL = false + Domain = "" # 地址必须是能够通过公网访问的,否则会出现图片无法显示。 + [OSS.QiNiu] # 七牛云 OSS 配置 + Zone = "z2" # 区域,z0:华东,z1: 华北,na0:北美,as0:新加坡 + AccessKey = "" + AccessSecret = "" + Bucket = "" + Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com + [OSS.AliYun] + Endpoint = "oss-cn-hangzhou.aliyuncs.com" + AccessKey = "" + AccessSecret = "" + Bucket = "chatgpt-plus" + SubDir = "" + Domain = "" + +[[MjProxyConfigs]] + Enabled = true + ApiURL = "http://midjourney-proxy:8082" + ApiKey = "sk-geekmaster" + +[[MjPlusConfigs]] + Enabled = false + ApiURL = "https://api.chat-plus.net" + Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo + ApiKey = "sk-xxx" + +[[SdConfigs]] + Enabled = false + ApiURL = "" + ApiKey = "" + Txt2ImgJsonPath = "res/sd/text2img.json" + +[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP,如果你没有启用支付服务,则该服务也无需启动 + Enabled = false # 是否启用 XXL JOB 服务 + ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址 + ExecutorIp = "172.22.11.47" # 执行器 IP 地址 + ExecutorPort = "9999" # 执行器服务端口 + AccessToken = "xxl-job-api-token" # 执行器 API 通信 token + RegistryKey = "chatgpt-plus" # 任务注册 key + +[AlipayConfig] + Enabled = false # 启用支付宝支付通道 + SandBox = false # 是否启用沙盒模式 + UserId = "2088721020750581" # 商户ID + AppId = "9021000131658023" # App Id + PrivateKey = "certs/alipay/privateKey.txt" # 应用私钥 + PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书 + AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书 + RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书 + NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址 + +[HuPiPayConfig] + Enabled = false + Name = "wechat" + AppId = "" + AppSecret = "" + ApiURL = "https://api.xunhupay.com" + NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify" + +[SmtpConfig] # 注意,阿里云服务器禁用了25号端口,请使用 465 端口,并开启 TLS 连接 + UseTls = false + Host = "smtp.163.com" + Port = 25 + AppName = "极客学长" + From = "test@163.com" # 发件邮箱人地址 + Password = "" #邮箱 stmp 服务授权码 + +[JPayConfig] # PayJs 支付配置 + Enabled = false + Name = "wechat" # 请不要改动 + AppId = "" # 商户 ID + PrivateKey = "" # 秘钥 + ApiURL = "https://payjs.cn" + NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的 \ No newline at end of file diff --git a/core/app_server.go b/core/app_server.go new file mode 100644 index 0000000..8ba434f --- /dev/null +++ b/core/app_server.go @@ -0,0 +1,380 @@ +package core + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bytes" + "context" + "fmt" + "geekai/core/types" + "geekai/store/model" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" + "github.com/golang-jwt/jwt/v5" + "github.com/nfnt/resize" + "golang.org/x/image/webp" + "gorm.io/gorm" + "image" + "image/jpeg" + "io" + "net/http" + "os" + "runtime/debug" + "strings" + "time" +) + +type AppServer struct { + Debug bool + Config *types.AppConfig + Engine *gin.Engine + ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message + + SysConfig *types.SystemConfig // system config cache + + // 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次 + // 防止第三方直接连接 socket 调用 OpenAI API + ChatSession *types.LMap[string, *types.ChatSession] //map[sessionId]UserId + ChatClients *types.LMap[string, *types.WsClient] // map[sessionId]Websocket 连接集合 + ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function +} + +func NewServer(appConfig *types.AppConfig) *AppServer { + gin.SetMode(gin.ReleaseMode) + gin.DefaultWriter = io.Discard + return &AppServer{ + Debug: false, + Config: appConfig, + Engine: gin.Default(), + ChatContexts: types.NewLMap[string, []types.Message](), + ChatSession: types.NewLMap[string, *types.ChatSession](), + ChatClients: types.NewLMap[string, *types.WsClient](), + ReqCancelFunc: types.NewLMap[string, context.CancelFunc](), + } +} + +func (s *AppServer) Init(debug bool, client *redis.Client) { + if debug { // 调试模式允许跨域请求 API + s.Debug = debug + logger.Info("Enabled debug mode") + } + s.Engine.Use(corsMiddleware()) + s.Engine.Use(staticResourceMiddleware()) + s.Engine.Use(authorizeMiddleware(s, client)) + s.Engine.Use(parameterHandlerMiddleware()) + s.Engine.Use(errorHandler) + // 添加静态资源访问 + s.Engine.Static("/static", s.Config.StaticDir) +} + +func (s *AppServer) Run(db *gorm.DB) error { + // load system configs + var sysConfig model.Config + res := db.Where("marker", "system").First(&sysConfig) + if res.Error != nil { + return res.Error + } + err := utils.JsonDecode(sysConfig.Config, &s.SysConfig) + if err != nil { + return err + } + logger.Infof("http://%s", s.Config.Listen) + return s.Engine.Run(s.Config.Listen) +} + +// 全局异常处理 +func errorHandler(c *gin.Context) { + defer func() { + if r := recover(); r != nil { + logger.Errorf("Handler Panic: %v", r) + debug.PrintStack() + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg}) + c.Abort() + } + }() + //加载完 defer recover,继续后续接口调用 + c.Next() +} + +// 跨域中间件设置 +func corsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + method := c.Request.Method + origin := c.Request.Header.Get("Origin") + if origin != "" { + // 设置允许的请求源 + c.Header("Access-Control-Allow-Origin", origin) + c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE") + //允许跨域设置可以返回其他子段,可以自定义字段 + c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, Chat-Token, Admin-Authorization") + // 允许浏览器(客户端)可以解析的头部 (重要) + c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers") + //设置缓存时间 + c.Header("Access-Control-Max-Age", "172800") + //允许客户端传递校验信息比如 cookie (重要) + c.Header("Access-Control-Allow-Credentials", "true") + } + + if method == http.MethodOptions { + c.JSON(http.StatusOK, "ok!") + } + + defer func() { + if err := recover(); err != nil { + logger.Info("Panic info is: %v", err) + } + }() + + c.Next() + } +} + +// 用户授权验证 +func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { + return func(c *gin.Context) { + var tokenString string + isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/") + if isAdminApi { // 后台管理 API + tokenString = c.GetHeader(types.AdminAuthHeader) + } else if c.Request.URL.Path == "/api/chat/new" { + tokenString = c.Query("token") + } else { + tokenString = c.GetHeader(types.UserAuthHeader) + } + + if tokenString == "" { + if needLogin(c) { + resp.ERROR(c, "You should put Authorization in request headers") + c.Abort() + return + } else { // 直接放行 + c.Next() + return + } + } + + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + if isAdminApi { + return []byte(s.Config.AdminSession.SecretKey), nil + } else { + return []byte(s.Config.Session.SecretKey), nil + } + + }) + + if err != nil && needLogin(c) { + resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err)) + c.Abort() + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid && needLogin(c) { + resp.NotAuth(c, "Token is invalid") + c.Abort() + return + } + + expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0) + if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) { + resp.NotAuth(c, "Token is expired") + c.Abort() + return + } + + key := fmt.Sprintf("users/%v", claims["user_id"]) + if isAdminApi { + key = fmt.Sprintf("admin/%v", claims["user_id"]) + } + if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) { + resp.NotAuth(c, "Token is not found in redis") + c.Abort() + return + } + c.Set(types.LoginUserID, claims["user_id"]) + } +} + +func needLogin(c *gin.Context) bool { + if c.Request.URL.Path == "/api/user/login" || + c.Request.URL.Path == "/api/user/logout" || + c.Request.URL.Path == "/api/user/resetPass" || + c.Request.URL.Path == "/api/admin/login" || + c.Request.URL.Path == "/api/admin/logout" || + c.Request.URL.Path == "/api/admin/login/captcha" || + c.Request.URL.Path == "/api/user/register" || + c.Request.URL.Path == "/api/user/session" || + c.Request.URL.Path == "/api/chat/history" || + c.Request.URL.Path == "/api/chat/detail" || + c.Request.URL.Path == "/api/chat/list" || + c.Request.URL.Path == "/api/role/list" || + c.Request.URL.Path == "/api/model/list" || + c.Request.URL.Path == "/api/mj/imgWall" || + c.Request.URL.Path == "/api/mj/client" || + c.Request.URL.Path == "/api/mj/notify" || + c.Request.URL.Path == "/api/invite/hits" || + c.Request.URL.Path == "/api/sd/imgWall" || + c.Request.URL.Path == "/api/sd/client" || + c.Request.URL.Path == "/api/dall/imgWall" || + c.Request.URL.Path == "/api/dall/client" || + c.Request.URL.Path == "/api/product/list" || + c.Request.URL.Path == "/api/menu/list" || + c.Request.URL.Path == "/api/markMap/client" || + c.Request.URL.Path == "/api/payment/alipay/notify" || + c.Request.URL.Path == "/api/payment/hupipay/notify" || + c.Request.URL.Path == "/api/payment/payjs/notify" || + c.Request.URL.Path == "/api/payment/doPay" || + c.Request.URL.Path == "/api/payment/payWays" || + strings.HasPrefix(c.Request.URL.Path, "/api/test") || + strings.HasPrefix(c.Request.URL.Path, "/api/config/") || + strings.HasPrefix(c.Request.URL.Path, "/api/function/") || + strings.HasPrefix(c.Request.URL.Path, "/api/sms/") || + strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") || + strings.HasPrefix(c.Request.URL.Path, "/static/") { + return false + } + return true +} + +// 统一参数处理 +func parameterHandlerMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // GET 参数处理 + params := c.Request.URL.Query() + for key, values := range params { + for i, value := range values { + params[key][i] = strings.TrimSpace(value) + } + } + // update get parameters + c.Request.URL.RawQuery = params.Encode() + // skip file upload requests + contentType := c.Request.Header.Get("Content-Type") + if strings.Contains(contentType, "multipart/form-data") { + c.Next() + return + } + + if strings.Contains(contentType, "application/json") { + // process POST JSON request body + bodyBytes, err := io.ReadAll(c.Request.Body) + if err != nil { + c.Next() + return + } + + // 还原请求体 + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + // 将请求体解析为 JSON + var jsonData map[string]interface{} + if err := c.ShouldBindJSON(&jsonData); err != nil { + c.Next() + return + } + + // 对 JSON 数据中的字符串值去除两端空格 + trimJSONStrings(jsonData) + // 更新请求体 + c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData))) + } + + c.Next() + } +} + +// 递归对 JSON 数据中的字符串值去除两端空格 +func trimJSONStrings(data interface{}) { + switch v := data.(type) { + case map[string]interface{}: + for key, value := range v { + switch valueType := value.(type) { + case string: + v[key] = strings.TrimSpace(valueType) + case map[string]interface{}, []interface{}: + trimJSONStrings(value) + } + } + case []interface{}: + for i, value := range v { + switch valueType := value.(type) { + case string: + v[i] = strings.TrimSpace(valueType) + case map[string]interface{}, []interface{}: + trimJSONStrings(value) + } + } + } +} + +// 静态资源中间件 +func staticResourceMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + + url := c.Request.URL.String() + // 拦截生成缩略图请求 + if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") { + r := strings.SplitAfter(url, "imageView2") + size := strings.Split(r[1], "/") + if len(size) != 8 { + c.String(http.StatusNotFound, "invalid thumb args") + return + } + with := utils.IntValue(size[3], 0) + height := utils.IntValue(size[5], 0) + quality := utils.IntValue(size[7], 75) + + // 打开图片文件 + filePath := strings.TrimLeft(c.Request.URL.Path, "/") + file, err := os.Open(filePath) + if err != nil { + c.String(http.StatusNotFound, "Image not found") + return + } + defer file.Close() + + // 解码图片 + img, _, err := image.Decode(file) + // for .webp image + if err != nil { + img, err = webp.Decode(file) + } + if err != nil { + c.String(http.StatusInternalServerError, "Error decoding image") + return + } + + var newImg image.Image + if height == 0 || with == 0 { + // 固定宽度,高度自适应 + newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3) + } else { + // 生成缩略图 + newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3) + } + var buffer bytes.Buffer + err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality}) + if err != nil { + logger.Error(err) + c.String(http.StatusInternalServerError, err.Error()) + return + } + + // 设置图片缓存有效期为一年 (365天) + c.Header("Cache-Control", "max-age=31536000, public") + // 直接输出图像数据流 + c.Data(http.StatusOK, "image/jpeg", buffer.Bytes()) + c.Abort() // 中断请求 + } + c.Next() + } +} diff --git a/core/config.go b/core/config.go new file mode 100644 index 0000000..0416049 --- /dev/null +++ b/core/config.go @@ -0,0 +1,77 @@ +package core + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bytes" + "geekai/core/types" + logger2 "geekai/logger" + "geekai/utils" + "os" + + "github.com/BurntSushi/toml" +) + +var logger = logger2.GetLogger() + +func NewDefaultConfig() *types.AppConfig { + return &types.AppConfig{ + Listen: "0.0.0.0:5678", + ProxyURL: "", + StaticDir: "./static", + StaticUrl: "http://localhost/5678/static", + Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""}, + Session: types.Session{ + SecretKey: utils.RandString(64), + MaxAge: 86400, + }, + ApiConfig: types.ApiConfig{}, + OSS: types.OSSConfig{ + Active: "local", + Local: types.LocalStorageConfig{ + BaseURL: "http://localhost/5678/static/upload", + BasePath: "./static/upload", + }, + }, + WeChatBot: false, + AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false}, + } +} + +func LoadConfig(configFile string) (*types.AppConfig, error) { + var config *types.AppConfig + _, err := os.Stat(configFile) + if err != nil { + logger.Info("creating new config file: ", configFile) + config = NewDefaultConfig() + config.Path = configFile + // save config + err := SaveConfig(config) + if err != nil { + return nil, err + } + + return config, nil + } + _, err = toml.DecodeFile(configFile, &config) + if err != nil { + return nil, err + } + + return config, err +} + +func SaveConfig(config *types.AppConfig) error { + buf := new(bytes.Buffer) + encoder := toml.NewEncoder(buf) + if err := encoder.Encode(&config); err != nil { + return err + } + + return os.WriteFile(config.Path, buf.Bytes(), 0644) +} diff --git a/core/types/chat.go b/core/types/chat.go new file mode 100644 index 0000000..3827b86 --- /dev/null +++ b/core/types/chat.go @@ -0,0 +1,119 @@ +package types + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +// ApiRequest API 请求实体 +type ApiRequest struct { + Model string `json:"model,omitempty"` // 兼容百度文心一言 + Temperature float32 `json:"temperature"` + MaxTokens int `json:"max_tokens,omitempty"` // 兼容百度文心一言 + Stream bool `json:"stream"` + Messages []interface{} `json:"messages,omitempty"` + Prompt []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM + Tools []Tool `json:"tools,omitempty"` + Functions []interface{} `json:"functions,omitempty"` // 兼容中转平台 + + ToolChoice string `json:"tool_choice,omitempty"` + + Input map[string]interface{} `json:"input,omitempty"` //兼容阿里通义千问 + Parameters map[string]interface{} `json:"parameters,omitempty"` //兼容阿里通义千问 +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ApiResponse struct { + Choices []ChoiceItem `json:"choices"` +} + +// ChoiceItem API 响应实体 +type ChoiceItem struct { + Delta Delta `json:"delta"` + FinishReason string `json:"finish_reason"` +} + +type Delta struct { + Role string `json:"role"` + Name string `json:"name"` + Content interface{} `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FunctionCall struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + } `json:"function_call,omitempty"` +} + +// ChatSession 聊天会话对象 +type ChatSession struct { + SessionId string `json:"session_id"` + ClientIP string `json:"client_ip"` // 客户端 IP + Username string `json:"username"` // 当前登录的 username + UserId uint `json:"user_id"` // 当前登录的 user ID + ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段 + Model ChatModel `json:"model"` // GPT 模型 +} + +type ChatModel struct { + Id uint `json:"id"` + Platform string `json:"platform"` + Name string `json:"name"` + Value string `json:"value"` + Power int `json:"power"` + MaxTokens int `json:"max_tokens"` // 最大响应长度 + MaxContext int `json:"max_context"` // 最大上下文长度 + Temperature float32 `json:"temperature"` // 模型温度 + KeyId int `json:"key_id"` // 绑定 API KEY +} + +type ApiError struct { + Error struct { + Message string + Type string + Param interface{} + Code string + } +} + +const PromptMsg = "prompt" // prompt message +const ReplyMsg = "reply" // reply message + +// PowerType 算力日志类型 +type PowerType int + +const ( + PowerRecharge = PowerType(1) // 充值 + PowerConsume = PowerType(2) // 消费 + PowerRefund = PowerType(3) // 任务(SD,MJ)执行失败,退款 + PowerInvite = PowerType(4) // 邀请奖励 + PowerReward = PowerType(5) // 众筹 + PowerGift = PowerType(6) // 系统赠送 +) + +func (t PowerType) String() string { + switch t { + case PowerRecharge: + return "充值" + case PowerConsume: + return "消费" + case PowerRefund: + return "退款" + case PowerReward: + return "众筹" + + } + return "其他" +} + +type PowerMark int + +const ( + PowerSub = PowerMark(0) + PowerAdd = PowerMark(1) +) diff --git a/core/types/client.go b/core/types/client.go new file mode 100644 index 0000000..5f65ac5 --- /dev/null +++ b/core/types/client.go @@ -0,0 +1,74 @@ +package types + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "errors" + "github.com/gorilla/websocket" + "sync" +) + +var ErrConClosed = errors.New("connection Closed") + +// WsClient websocket client +type WsClient struct { + Conn *websocket.Conn + lock sync.Mutex + mt int + Closed bool +} + +func NewWsClient(conn *websocket.Conn) *WsClient { + return &WsClient{ + Conn: conn, + lock: sync.Mutex{}, + mt: 2, // fixed bug for 'Invalid UTF-8 in text frame' + Closed: false, + } +} + +func (wc *WsClient) Send(message []byte) error { + wc.lock.Lock() + defer wc.lock.Unlock() + + if wc.Closed { + return ErrConClosed + } + + return wc.Conn.WriteMessage(wc.mt, message) +} + +func (wc *WsClient) SendJson(value interface{}) error { + wc.lock.Lock() + defer wc.lock.Unlock() + + if wc.Closed { + return ErrConClosed + } + return wc.Conn.WriteJSON(value) +} + +func (wc *WsClient) Receive() (int, []byte, error) { + if wc.Closed { + return 0, nil, ErrConClosed + } + + return wc.Conn.ReadMessage() +} + +func (wc *WsClient) Close() { + wc.lock.Lock() + defer wc.lock.Unlock() + + if wc.Closed { + return + } + + _ = wc.Conn.Close() + wc.Closed = true +} diff --git a/core/types/config.go b/core/types/config.go new file mode 100644 index 0000000..c988199 --- /dev/null +++ b/core/types/config.go @@ -0,0 +1,217 @@ +package types + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" +) + +type AppConfig struct { + Path string `toml:"-"` + Listen string + Session Session + AdminSession Session + ProxyURL string + MysqlDns string // mysql 连接地址 + StaticDir string // 静态资源目录 + StaticUrl string // 静态资源 URL + Redis RedisConfig // redis 连接信息 + ApiConfig ApiConfig // ChatPlus API authorization configs + SMS SMSConfig // send mobile message config + OSS OSSConfig // OSS config + MjProxyConfigs []MjProxyConfig // MJ proxy config + MjPlusConfigs []MjPlusConfig // MJ plus config + WeChatBot bool // 是否启用微信机器人 + SdConfigs []StableDiffusionConfig // sd AI draw service pool + + XXLConfig XXLConfig + AlipayConfig AlipayConfig + HuPiPayConfig HuPiPayConfig + SmtpConfig SmtpConfig // 邮件发送配置 + JPayConfig JPayConfig // payjs 支付配置 +} + +type SmtpConfig struct { + UseTls bool // 是否使用 TLS 发送 + Host string + Port int + AppName string // 应用名称 + From string // 发件人邮箱地址 + Password string // 发件人邮箱密码 +} + +type ApiConfig struct { + ApiURL string + AppId string + Token string +} + +type MjProxyConfig struct { + Enabled bool + ApiURL string // api 地址 + Mode string // 绘画模式,可选值:fast/turbo/relax + ApiKey string +} + +type StableDiffusionConfig struct { + Enabled bool + Model string // 模型名称 + ApiURL string + ApiKey string +} + +type MjPlusConfig struct { + Enabled bool // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务 + ApiURL string // api 地址 + Mode string // 绘画模式,可选值:fast/turbo/relax + ApiKey string +} + +type AlipayConfig struct { + Enabled bool // 是否启用该支付通道 + SandBox bool // 是否沙盒环境 + AppId string // 应用 ID + UserId string // 支付宝用户 ID + PrivateKey string // 用户私钥文件路径 + PublicKey string // 用户公钥文件路径 + AlipayPublicKey string // 支付宝公钥文件路径 + RootCert string // Root 秘钥路径 + NotifyURL string // 异步通知回调 + ReturnURL string // 支付成功返回地址 +} + +type HuPiPayConfig struct { //虎皮椒第四方支付配置 + Enabled bool // 是否启用该支付通道 + Name string // 支付名称,如:wechat/alipay + AppId string // App ID + AppSecret string // app 密钥 + ApiURL string // 支付网关 + NotifyURL string // 异步通知回调 + ReturnURL string // 支付成功返回地址 +} + +// JPayConfig PayJs 支付配置 +type JPayConfig struct { + Enabled bool + Name string // 支付名称,默认 wechat + AppId string // 商户 ID + PrivateKey string // 私钥 + ApiURL string // API 网关 + NotifyURL string // 异步回调地址 + ReturnURL string // 支付成功返回地址 +} + +type XXLConfig struct { // XXL 任务调度配置 + Enabled bool + ServerAddr string + ExecutorIp string + ExecutorPort string + AccessToken string + RegistryKey string +} + +type RedisConfig struct { + Host string + Port int + Password string + DB int +} + +// LicenseKey 存储许可证书的 KEY +const LicenseKey = "Geek-AI-License" + +type License struct { + Key string `json:"key"` // 许可证书密钥 + MachineId string `json:"machine_id"` // 机器码 + ExpiredAt int64 `json:"expired_at"` // 过期时间 + IsActive bool `json:"is_active"` // 是否激活 + Configs LicenseConfig `json:"configs"` +} + +type LicenseConfig struct { + UserNum int `json:"user_num"` // 用户数量 + DeCopy bool `json:"de_copy"` // 去版权 +} + +func (c RedisConfig) Url() string { + return fmt.Sprintf("%s:%d", c.Host, c.Port) +} + +type Platform struct { + Name string `json:"name"` + Value string `json:"value"` + ChatURL string `json:"chat_url"` + ImgURL string `json:"img_url"` +} + +var OpenAI = Platform{ + Name: "OpenAI - GPT", + Value: "OpenAI", + ChatURL: "https://api.chat-plus.net/v1/chat/completions", + ImgURL: "https://api.chat-plus.net/v1/images/generations", +} +var Azure = Platform{ + Name: "微软 - Azure", + Value: "Azure", + ChatURL: "https://chat-bot-api.openai.azure.com/openai/deployments/{model}/chat/completions?api-version=2023-05-15", +} +var ChatGLM = Platform{ + Name: "智谱 - ChatGLM", + Value: "ChatGLM", + ChatURL: "https://open.bigmodel.cn/api/paas/v3/model-api/{model}/sse-invoke", +} +var Baidu = Platform{ + Name: "百度 - 文心大模型", + Value: "Baidu", + ChatURL: "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}", +} +var XunFei = Platform{ + Name: "讯飞 - 星火大模型", + Value: "XunFei", + ChatURL: "wss://spark-api.xf-yun.com/{version}/chat", +} +var QWen = Platform{ + Name: "阿里 - 通义千问", + Value: "QWen", + ChatURL: "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", +} + +type SystemConfig struct { + Title string `json:"title,omitempty"` + AdminTitle string `json:"admin_title,omitempty"` + Logo string `json:"logo,omitempty"` + InitPower int `json:"init_power,omitempty"` // 新用户注册赠送算力值 + DailyPower int `json:"daily_power,omitempty"` // 每日赠送算力 + InvitePower int `json:"invite_power,omitempty"` // 邀请新用户赠送算力值 + VipMonthPower int `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值 + + RegisterWays []string `json:"register_ways,omitempty"` // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册 + EnabledRegister bool `json:"enabled_register,omitempty"` // 是否开放注册 + + RewardImg string `json:"reward_img,omitempty"` // 众筹收款二维码地址 + EnabledReward bool `json:"enabled_reward,omitempty"` // 启用众筹功能 + PowerPrice float64 `json:"power_price,omitempty"` // 算力单价 + + OrderPayTimeout int `json:"order_pay_timeout,omitempty"` //订单支付超时时间 + VipInfoText string `json:"vip_info_text,omitempty"` // 会员页面充值说明 + DefaultModels []int `json:"default_models,omitempty"` // 默认开通的 AI 模型 + + MjPower int `json:"mj_power,omitempty"` // MJ 绘画消耗算力 + MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力 + SdPower int `json:"sd_power,omitempty"` // SD 绘画消耗算力 + DallPower int `json:"dall_power,omitempty"` // DALLE3 绘图消耗算力 + + WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址 + + EnableContext bool `json:"enable_context,omitempty"` + ContextDeep int `json:"context_deep,omitempty"` + + SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词 + + RandBg bool `json:"rand_bg"` // 前端首页是否启用随机背景 +} diff --git a/core/types/function.go b/core/types/function.go new file mode 100644 index 0000000..0897a7d --- /dev/null +++ b/core/types/function.go @@ -0,0 +1,27 @@ +package types + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +type ToolCall struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` +} + +type Tool struct { + Type string `json:"type"` + Function Function `json:"function"` +} + +type Function struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` +} diff --git a/core/types/locked_map.go b/core/types/locked_map.go new file mode 100644 index 0000000..5ae764b --- /dev/null +++ b/core/types/locked_map.go @@ -0,0 +1,70 @@ +package types + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "context" + "sync" +) + +type MKey interface { + string | int | uint +} +type MValue interface { + *WsClient | *ChatSession | context.CancelFunc | []Message +} +type LMap[K MKey, T MValue] struct { + lock sync.RWMutex + data map[K]T +} + +func NewLMap[K MKey, T MValue]() *LMap[K, T] { + return &LMap[K, T]{ + lock: sync.RWMutex{}, + data: make(map[K]T), + } +} + +func (m *LMap[K, T]) Put(key K, value T) { + m.lock.Lock() + defer m.lock.Unlock() + + m.data[key] = value +} + +func (m *LMap[K, T]) Get(key K) T { + m.lock.RLock() + defer m.lock.RUnlock() + + return m.data[key] +} + +func (m *LMap[K, T]) Has(key K) bool { + m.lock.RLock() + defer m.lock.RUnlock() + _, ok := m.data[key] + return ok +} + +func (m *LMap[K, T]) Delete(key K) { + m.lock.Lock() + defer m.lock.Unlock() + + delete(m.data, key) +} + +func (m *LMap[K, T]) ToList() []T { + m.lock.Lock() + defer m.lock.Unlock() + + var s = make([]T, 0) + for _, v := range m.data { + s = append(s, v) + } + return s +} diff --git a/core/types/order.go b/core/types/order.go new file mode 100644 index 0000000..90cc0cb --- /dev/null +++ b/core/types/order.go @@ -0,0 +1,24 @@ +package types + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +type OrderStatus int + +const ( + OrderNotPaid = OrderStatus(0) + OrderScanned = OrderStatus(1) // 已扫码 + OrderPaidSuccess = OrderStatus(2) +) + +type OrderRemark struct { + Days int `json:"days"` // 有效期 + Power int `json:"power"` // 增加算力点数 + Name string `json:"name"` // 产品名称 + Price float64 `json:"price"` + Discount float64 `json:"discount"` +} diff --git a/core/types/oss.go b/core/types/oss.go new file mode 100644 index 0000000..9bc93b4 --- /dev/null +++ b/core/types/oss.go @@ -0,0 +1,48 @@ +package types + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +type OSSConfig struct { + Active string + Local LocalStorageConfig + Minio MiniOssConfig + QiNiu QiNiuOssConfig + AliYun AliYunOssConfig +} +type MiniOssConfig struct { + Endpoint string + AccessKey string + AccessSecret string + Bucket string + SubDir string + UseSSL bool + Domain string +} + +type QiNiuOssConfig struct { + Zone string + AccessKey string + AccessSecret string + Bucket string + SubDir string + Domain string +} + +type AliYunOssConfig struct { + Endpoint string + AccessKey string + AccessSecret string + Bucket string + SubDir string + Domain string +} + +type LocalStorageConfig struct { + BasePath string + BaseURL string +} diff --git a/core/types/session.go b/core/types/session.go new file mode 100644 index 0000000..9108e51 --- /dev/null +++ b/core/types/session.go @@ -0,0 +1,20 @@ +package types + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +const LoginUserID = "LOGIN_USER_ID" +const LoginUserCache = "LOGIN_USER_CACHE" + +const UserAuthHeader = "Authorization" +const AdminAuthHeader = "Admin-Authorization" + +// Session configs struct +type Session struct { + SecretKey string + MaxAge int +} diff --git a/core/types/sms.go b/core/types/sms.go new file mode 100644 index 0000000..510e807 --- /dev/null +++ b/core/types/sms.go @@ -0,0 +1,33 @@ +package types + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +type SMSConfig struct { + Active string + Ali SmsConfigAli + Bao SmsConfigBao +} + +// SmsConfigAli 阿里云短信平台配置 +type SmsConfigAli struct { + AccessKey string + AccessSecret string + Product string + Domain string + Sign string // 短信签名 + CodeTempId string // 验证码短信模板 ID +} + +// SmsConfigBao 短信宝平台配置 +type SmsConfigBao struct { + Username string //短信宝平台注册的用户名 + Password string //短信宝平台注册的密码 + Domain string //域名 + Sign string // 短信签名 + CodeTemplate string // 验证码短信模板 匹配 +} diff --git a/core/types/task.go b/core/types/task.go new file mode 100644 index 0000000..6b6a364 --- /dev/null +++ b/core/types/task.go @@ -0,0 +1,82 @@ +package types + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +// TaskType 任务类别 +type TaskType string + +func (t TaskType) String() string { + return string(t) +} + +const ( + TaskImage = TaskType("image") + TaskBlend = TaskType("blend") + TaskSwapFace = TaskType("swapFace") + TaskUpscale = TaskType("upscale") + TaskVariation = TaskType("variation") +) + +// MjTask MidJourney 任务 +type MjTask struct { + Id uint `json:"id"` + TaskId string `json:"task_id"` + ImgArr []string `json:"img_arr"` + ChannelId string `json:"channel_id"` + SessionId string `json:"session_id"` + Type TaskType `json:"type"` + UserId int `json:"user_id"` + Prompt string `json:"prompt,omitempty"` + NegPrompt string `json:"neg_prompt,omitempty"` + Params string `json:"full_prompt"` + Index int `json:"index,omitempty"` + MessageId string `json:"message_id,omitempty"` + MessageHash string `json:"message_hash,omitempty"` + RetryCount int `json:"retry_count"` +} + +type SdTask struct { + Id int `json:"id"` // job 数据库ID + SessionId string `json:"session_id"` + Type TaskType `json:"type"` + UserId int `json:"user_id"` + Params SdTaskParams `json:"params"` + RetryCount int `json:"retry_count"` +} + +type SdTaskParams struct { + TaskId string `json:"task_id"` + Prompt string `json:"prompt"` // 提示词 + NegPrompt string `json:"neg_prompt"` // 反向提示词 + Steps int `json:"steps"` // 迭代步数,默认20 + Sampler string `json:"sampler"` // 采样器 + Scheduler string `json:"scheduler"` + FaceFix bool `json:"face_fix"` // 面部修复 + CfgScale float32 `json:"cfg_scale"` //引导系数,默认 7 + Seed int64 `json:"seed"` // 随机数种子 + Height int `json:"height"` + Width int `json:"width"` + HdFix bool `json:"hd_fix"` // 启用高清修复 + HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度 + HdScale int `json:"hd_scale"` // 放大倍数 + HdScaleAlg string `json:"hd_scale_alg"` // 放大算法 + HdSteps int `json:"hd_steps"` // 高清修复迭代步数 +} + +// DallTask DALL-E task +type DallTask struct { + JobId uint `json:"job_id"` + UserId uint `json:"user_id"` + Prompt string `json:"prompt"` + N int `json:"n"` + Quality string `json:"quality"` + Size string `json:"size"` + Style string `json:"style"` + + Power int `json:"power"` +} diff --git a/core/types/web.go b/core/types/web.go new file mode 100644 index 0000000..08f5ee9 --- /dev/null +++ b/core/types/web.go @@ -0,0 +1,46 @@ +package types + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +// BizVo 业务返回 VO +type BizVo struct { + Code BizCode `json:"code"` + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + Total int `json:"total,omitempty"` + Message string `json:"message,omitempty"` + Data interface{} `json:"data,omitempty"` +} + +// WsMessage Websocket message +type WsMessage struct { + Type WsMsgType `json:"type"` // 消息类别,start, end, img + Content interface{} `json:"content"` +} +type WsMsgType string + +const ( + WsStart = WsMsgType("start") + WsMiddle = WsMsgType("middle") + WsEnd = WsMsgType("end") + WsErr = WsMsgType("error") +) + +type BizCode int + +const ( + Success = BizCode(0) + Failed = BizCode(1) + NotAuthorized = BizCode(400) // 未授权 + NotPermission = BizCode(403) // 没有权限 + + OkMsg = "Success" + ErrorMsg = "系统开小差了" + InvalidArgs = "非法参数或参数解析失败" + NoData = "No Data" +) diff --git a/fresh.conf b/fresh.conf new file mode 100644 index 0000000..aac77af --- /dev/null +++ b/fresh.conf @@ -0,0 +1,14 @@ +root: . +tmp_path: ./tmp +build_name: runner-build +build_log: runner-build-errors.log +valid_ext: .go, .tpl, .tmpl, .html +no_rebuild_ext: .tpl, .tmpl, .html, .js, .vue +ignored: assets, tmp, web, .git, .idea, test, data +build_delay: 600 +colors: 1 +log_color_main: cyan +log_color_build: yellow +log_color_runner: green +log_color_watcher: magenta +log_color_app: diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3f08038 --- /dev/null +++ b/go.mod @@ -0,0 +1,117 @@ +module geekai + +go 1.21 + +toolchain go1.22.4 + +require ( + github.com/BurntSushi/toml v1.1.0 + github.com/aliyun/alibaba-cloud-sdk-go v1.62.405 + github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible + github.com/eatmoreapple/openwechat v1.2.1 + github.com/gin-gonic/gin v1.9.1 + github.com/go-redis/redis/v8 v8.11.5 + github.com/golang-jwt/jwt/v5 v5.0.0 + github.com/gorilla/websocket v1.5.0 + github.com/imroc/req/v3 v3.37.2 + github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0 + github.com/minio/minio-go/v7 v7.0.62 + github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 + github.com/qiniu/go-sdk/v7 v7.17.1 + github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e + github.com/smartwalle/alipay/v3 v3.2.15 + go.uber.org/zap v1.23.0 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 + gorm.io/driver/mysql v1.4.7 +) + +require github.com/xxl-job/xxl-job-executor-go v1.2.0 + +require ( + github.com/mojocn/base64Captcha v1.3.1 + github.com/shirou/gopsutil v3.21.11+incompatible + github.com/shopspring/decimal v1.3.1 + github.com/syndtr/goleveldb v1.0.0 + golang.org/x/image v0.0.0-20211028202545-6944b10bf410 +) + +require ( + github.com/go-ole/go-ole v1.2.6 // indirect + github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect + github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect + github.com/tklauser/go-sysconf v0.3.14 // indirect + github.com/tklauser/numcpus v0.8.0 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect + go.uber.org/mock v0.4.0 // indirect +) + +require ( + github.com/andybalholm/brotli v1.0.4 // indirect + github.com/bytedance/sonic v1.9.1 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.8.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gaukas/godicttls v0.0.3 // indirect + github.com/go-basic/ipv4 v1.0.0 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af // indirect + github.com/klauspost/compress v1.16.7 // indirect + github.com/klauspost/cpuid/v2 v2.2.5 // indirect + github.com/minio/md5-simd v1.1.2 // indirect + github.com/minio/sha256-simd v1.0.1 // indirect + github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 + github.com/onsi/ginkgo/v2 v2.10.0 // indirect + github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/quic-go/qpack v0.4.0 // indirect + github.com/quic-go/quic-go v0.45.0 // indirect + github.com/refraction-networking/utls v1.3.2 // indirect + github.com/rs/xid v1.5.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/smartwalle/ncrypto v1.0.2 // indirect + github.com/smartwalle/ngx v1.0.6 // indirect + github.com/smartwalle/nsign v1.0.8 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + go.uber.org/dig v1.16.1 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/net v0.25.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/text v0.15.0 // indirect + golang.org/x/time v0.5.0 // indirect + golang.org/x/tools v0.21.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +require ( + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/fx v1.19.3 + go.uber.org/multierr v1.6.0 // indirect + golang.org/x/crypto v0.23.0 + golang.org/x/sys v0.20.0 // indirect + gorm.io/gorm v1.25.1 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8c7ec59 --- /dev/null +++ b/go.sum @@ -0,0 +1,330 @@ +github.com/BurntSushi/toml v1.1.0 h1:ksErzDEI1khOiGPgpwuI7x2ebx/uXQNw7xJpn9Eq1+I= +github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= +github.com/aliyun/alibaba-cloud-sdk-go v1.62.405 h1:cKNFQmeCQFN0WNfjScKoVrGi7vXxTVbkCvCqSrOf+P4= +github.com/aliyun/alibaba-cloud-sdk-go v1.62.405/go.mod h1:Api2AkmMgGaSUAhmk76oaFObkoeCPc/bKAqcyplPODs= +github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible h1:Sg/2xHwDrioHpxTN6WMiwbXTpUEinBpHsN7mG21Rc2k= +github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8= +github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= +github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= +github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= +github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/eatmoreapple/openwechat v1.2.1 h1:ez4oqF/Y2NSEX/DbPV8lvj7JlfkYqvieeo4awx5lzfU= +github.com/eatmoreapple/openwechat v1.2.1/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk= +github.com/gaukas/godicttls v0.0.3/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs= +github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg= +github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.8.0/go.mod h1:9JhgTzTaE31GZDpH/HSvHiRJrJ3iKAgqqH0Bl/Ocjdk= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= +github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs= +github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/imroc/req/v3 v3.37.2 h1:vEemuA0cq9zJ6lhe+mSRhsZm951bT0CdiSH47+KTn6I= +github.com/imroc/req/v3 v3.37.2/go.mod h1:DECzjVIrj6jcUr5n6e+z0ygmCO93rx4Jy0RjOEe1YCI= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= +github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= +github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0 h1:LgmjED/yQILqmUED4GaXjrINWe7YJh4HM6z2EvEINPs= +github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230415042440-a5e3d8259ae0/go.mod h1:C5LA5UO2ZXJrLaPLYtE1wUJMiyd/nwWaCO5cw/2pSHs= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= +github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= +github.com/minio/minio-go/v7 v7.0.62 h1:qNYsFZHEzl+NfH8UxW4jpmlKav1qUAgfY30YNRneVhc= +github.com/minio/minio-go/v7 v7.0.62/go.mod h1:Q6X7Qjb7WMhvG65qKf4gUgA5XaiSox74kR1uAEjxRS4= +github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM= +github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0= +github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/ginkgo/v2 v2.10.0 h1:sfUl4qgLdvkChZrWCYndY2EAu9BRIw1YphNAzy1VNWs= +github.com/onsi/ginkgo/v2 v2.10.0/go.mod h1:UDQOh5wbQUlMnkLfVaIUMtQ1Vus92oM+P2JX1aulgcE= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU= +github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4= +github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A= +github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU= +github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= +github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 h1:IFhPCcB0/HtnEN+ZoUGDT55YgFCymbFJ15kXqs3nv5w= +github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480/go.mod h1:BijIqAP84FMYC4XbdJgjyMpiSjusU8x0Y0W9K2t0QtU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/qiniu/dyn v1.3.0/go.mod h1:E8oERcm8TtwJiZvkQPbcAh0RL8jO1G0VXJMW3FAWdkk= +github.com/qiniu/go-sdk/v7 v7.17.1 h1:UoQv7fBKtzAiD1qZPIvTy62Se48YLKxcCYP9nAwWMa0= +github.com/qiniu/go-sdk/v7 v7.17.1/go.mod h1:nqoYCNo53ZlGA521RvRethvxUDvXKt4gtYXOwye868w= +github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs= +github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= +github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= +github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE= +github.com/quic-go/quic-go v0.45.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI= +github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8= +github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= +github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= +github.com/smartwalle/alipay/v3 v3.2.15 h1:3fvFJnINKKAOXHR/Iv20k1Z7KJ+nOh3oK214lELPqG8= +github.com/smartwalle/alipay/v3 v3.2.15/go.mod h1:niTNB609KyUYuAx9Bex/MawEjv2yPx4XOjxSAkqmGjE= +github.com/smartwalle/ncrypto v1.0.2 h1:pTAhCqtPCMhpOwFXX+EcMdR6PNzruBNoGQrN2S1GbGI= +github.com/smartwalle/ncrypto v1.0.2/go.mod h1:Dwlp6sfeNaPMnOxMNayMTacvC5JGEVln3CVdiVDgbBk= +github.com/smartwalle/ngx v1.0.6 h1:JPNqNOIj+2nxxFtrSkJO+vKJfeNUSEQueck/Wworjps= +github.com/smartwalle/ngx v1.0.6/go.mod h1:mx/nz2Pk5j+RBs7t6u6k22MPiBG/8CtOMpCnALIG8Y0= +github.com/smartwalle/nsign v1.0.8 h1:78KWtwKPrdt4Xsn+tNEBVxaTLIJBX9YRX0ZSrMUeuHo= +github.com/smartwalle/nsign v1.0.8/go.mod h1:eY6I4CJlyNdVMP+t6z1H6Jpd4m5/V+8xi44ufSTxXgc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= +github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= +github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU= +github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY= +github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY= +github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o= +github.com/uber/jaeger-client-go v2.30.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= +github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVKhn2Um6rjCsSsg= +github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onqm9OaSarneeLQ= +github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/dig v1.16.1 h1:+alNIBsl0qfY0j6epRubp/9obgtrObRAc5aD+6jbWY8= +go.uber.org/dig v1.16.1/go.mod h1:557JTAUZT5bUK0SvCwikmLPPtdQhfvLYtO5tJgQSbnk= +go.uber.org/fx v1.19.3 h1:YqMRE4+2IepTYCMOvXqQpRa+QAVdiSTnsHU4XNWBceA= +go.uber.org/fx v1.19.3/go.mod h1:w2HrQg26ql9fLK7hlBiZ6JsRUKV+Lj/atT1KCjT8YhM= +go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= +go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY= +go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ= +golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= +golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.4.7 h1:rY46lkCspzGHn7+IYsNpSfEv9tA+SU4SkkB+GFX125Y= +gorm.io/driver/mysql v1.4.7/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8oc= +gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= +gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64= +gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/handler/admin/admin_handler.go b/handler/admin/admin_handler.go new file mode 100644 index 0000000..3fceab8 --- /dev/null +++ b/handler/admin/admin_handler.go @@ -0,0 +1,279 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + logger2 "geekai/logger" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "context" + "fmt" + "github.com/go-redis/redis/v8" + "github.com/golang-jwt/jwt/v5" + "github.com/mojocn/base64Captcha" + "time" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +var logger = logger2.GetLogger() + +// Manager 管理员 +type Manager struct { + Username string `json:"username"` + Password string `json:"password"` + Captcha string `json:"captcha"` // 验证码 + CaptchaId string `json:"captcha_id"` // 验证码id +} + +const SuperManagerID = 1 + +type ManagerHandler struct { + handler.BaseHandler + redis *redis.Client +} + +func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler { + return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client} +} + +// Login 登录 +func (h *ManagerHandler) Login(c *gin.Context) { + var data Manager + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + // add captcha + if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) { + resp.ERROR(c, "验证码错误!") + return + } + + var manager model.AdminUser + res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager) + if res.Error != nil { + resp.ERROR(c, "请检查用户名或者密码是否填写正确") + return + } + password := utils.GenPassword(data.Password, manager.Salt) + if password != manager.Password { + resp.ERROR(c, "用户名或密码错误") + return + } + + // 超级管理员默认是ID:1 + if manager.Id != SuperManagerID && manager.Status == false { + resp.ERROR(c, "该用户已被禁止登录,请联系超级管理员") + return + } + + // 创建 token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": manager.Id, + "expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(), + }) + tokenString, err := token.SignedString([]byte(h.App.Config.AdminSession.SecretKey)) + if err != nil { + resp.ERROR(c, "Failed to generate token, "+err.Error()) + return + } + // 保存到 redis + key := fmt.Sprintf("admin/%d", manager.Id) + if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil { + resp.ERROR(c, "error with save token: "+err.Error()) + return + } + + // 更新最后登录时间和IP + manager.LastLoginIp = c.ClientIP() + manager.LastLoginAt = time.Now().Unix() + h.DB.Updates(&manager) + + var result = struct { + IsSuperAdmin bool `json:"is_super_admin"` + Token string `json:"token"` + }{ + IsSuperAdmin: manager.Id == 1, + Token: tokenString, + } + + resp.SUCCESS(c, result) +} + +// Logout 注销 +func (h *ManagerHandler) Logout(c *gin.Context) { + key := h.GetUserKey(c) + if _, err := h.redis.Del(c, key).Result(); err != nil { + logger.Error("error with delete session: ", err) + } else { + resp.SUCCESS(c) + } +} + +// Session 会话检测 +func (h *ManagerHandler) Session(c *gin.Context) { + id := h.GetLoginUserId(c) + key := fmt.Sprintf("admin/%d", id) + if _, err := h.redis.Get(context.Background(), key).Result(); err != nil { + resp.NotAuth(c) + return + } + var manager model.AdminUser + res := h.DB.Where("id", id).First(&manager) + if res.Error != nil { + resp.NotAuth(c) + return + } + + resp.SUCCESS(c, manager) +} + +// List 数据列表 +func (h *ManagerHandler) List(c *gin.Context) { + var items []model.AdminUser + res := h.DB.Find(&items) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + users := make([]vo.AdminUser, 0) + for _, item := range items { + var u vo.AdminUser + err := utils.CopyObject(item, &u) + if err != nil { + continue + } + u.Id = item.Id + u.CreatedAt = item.CreatedAt.Unix() + users = append(users, u) + } + + resp.SUCCESS(c, users) + +} + +func (h *ManagerHandler) Save(c *gin.Context) { + var data struct { + Username string `json:"username"` + Password string `json:"password"` + Status bool `json:"status"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + var user model.AdminUser + res := h.DB.Where("username", data.Username).First(&user) + if res.Error == nil { + resp.ERROR(c, "用户名已存在") + return + } + + // 生成密码 + salt := utils.RandString(8) + password := utils.GenPassword(data.Password, salt) + res = h.DB.Save(&model.AdminUser{ + Username: data.Username, + Password: password, + Salt: salt, + Status: data.Status, + }) + if res.Error != nil { + resp.ERROR(c, "failed with update database") + return + } + + resp.SUCCESS(c) +} + +// Remove 删除管理员 +func (h *ManagerHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + if id <= 0 { + resp.ERROR(c, types.InvalidArgs) + return + } + + if id == SuperManagerID { + resp.ERROR(c, "超级管理员不能删除") + return + } + + res := h.DB.Where("id", id).Delete(&model.AdminUser{}) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + resp.SUCCESS(c) +} + +// Enable 启用/禁用 +func (h *ManagerHandler) Enable(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Enabled bool `json:"enabled"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.AdminUser{}).Where("id", data.Id).UpdateColumn("status", data.Enabled) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + resp.SUCCESS(c) +} + +// ResetPass 重置密码 +func (h *ManagerHandler) ResetPass(c *gin.Context) { + id := h.GetLoginUserId(c) + if id != SuperManagerID { + resp.ERROR(c, "只有超级管理员能够进行该操作") + return + } + + var data struct { + Id int `json:"id"` + Password string `json:"password"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + var user model.AdminUser + res := h.DB.Where("id", data.Id).First(&user) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + password := utils.GenPassword(data.Password, user.Salt) + user.Password = password + res = h.DB.Updates(&user) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + resp.SUCCESS(c) +} diff --git a/handler/admin/api_key_handler.go b/handler/admin/api_key_handler.go new file mode 100644 index 0000000..f412c03 --- /dev/null +++ b/handler/admin/api_key_handler.go @@ -0,0 +1,147 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type ApiKeyHandler struct { + handler.BaseHandler +} + +func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler { + return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}} +} + +func (h *ApiKeyHandler) Save(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Platform string `json:"platform"` + Name string `json:"name"` + Type string `json:"type"` + Value string `json:"value"` + ApiURL string `json:"api_url"` + Enabled bool `json:"enabled"` + ProxyURL string `json:"proxy_url"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + apiKey := model.ApiKey{} + if data.Id > 0 { + h.DB.Find(&apiKey, data.Id) + } + apiKey.Platform = data.Platform + apiKey.Value = data.Value + apiKey.Type = data.Type + apiKey.ApiURL = data.ApiURL + apiKey.Enabled = data.Enabled + apiKey.ProxyURL = data.ProxyURL + apiKey.Name = data.Name + res := h.DB.Save(&apiKey) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + + var keyVo vo.ApiKey + err := utils.CopyObject(apiKey, &keyVo) + if err != nil { + resp.ERROR(c, "数据拷贝失败!") + return + } + keyVo.Id = apiKey.Id + keyVo.CreatedAt = apiKey.CreatedAt.Unix() + resp.SUCCESS(c, keyVo) +} + +func (h *ApiKeyHandler) List(c *gin.Context) { + status := h.GetBool(c, "status") + t := h.GetTrim(c, "type") + platform := h.GetTrim(c, "platform") + + session := h.DB.Session(&gorm.Session{}) + if status { + session = session.Where("enabled", true) + } + if t != "" { + session = session.Where("type", t) + } + if platform != "" { + session = session.Where("platform", platform) + } + + var items []model.ApiKey + var keys = make([]vo.ApiKey, 0) + res := session.Find(&items) + if res.Error == nil { + for _, item := range items { + var key vo.ApiKey + err := utils.CopyObject(item, &key) + if err == nil { + key.Id = item.Id + key.CreatedAt = item.CreatedAt.Unix() + key.UpdatedAt = item.UpdatedAt.Unix() + keys = append(keys, key) + } else { + logger.Error(err) + } + } + } + resp.SUCCESS(c, keys) +} + +func (h *ApiKeyHandler) Set(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Filed string `json:"filed"` + Value interface{} `json:"value"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + resp.SUCCESS(c) +} + +func (h *ApiKeyHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + if id <= 0 { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Where("id", id).Delete(&model.ApiKey{}) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + resp.SUCCESS(c) +} diff --git a/handler/admin/captcha_handler.go b/handler/admin/captcha_handler.go new file mode 100644 index 0000000..e6b81cb --- /dev/null +++ b/handler/admin/captcha_handler.go @@ -0,0 +1,46 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/handler" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "github.com/mojocn/base64Captcha" +) + +type CaptchaHandler struct { + handler.BaseHandler +} + +func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler { + return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}} +} + +type CaptchaVo struct { + CaptchaId string `json:"captcha_id"` + PicPath string `json:"pic_path"` +} + +// GetCaptcha 获取验证码 +func (h *CaptchaHandler) GetCaptcha(c *gin.Context) { + var captchaVo CaptchaVo + driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10) + cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore) + // b64s是图片的base64编码 + id, b64s, err := cp.Generate() + if err != nil { + resp.ERROR(c, "生成验证码错误!") + return + } + captchaVo.CaptchaId = id + captchaVo.PicPath = b64s + + resp.SUCCESS(c, captchaVo) +} diff --git a/handler/admin/chat_handler.go b/handler/admin/chat_handler.go new file mode 100644 index 0000000..f51cf85 --- /dev/null +++ b/handler/admin/chat_handler.go @@ -0,0 +1,269 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type ChatHandler struct { + handler.BaseHandler +} + +func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler { + return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +} + +type chatItemVo struct { + Username string `json:"username"` + UserId uint `json:"user_id"` + ChatId string `json:"chat_id"` + Title string `json:"title"` + Role vo.ChatRole `json:"role"` + Model string `json:"model"` + Token int `json:"token"` + CreatedAt int64 `json:"created_at"` + MsgNum int `json:"msg_num"` // 消息数量 +} + +func (h *ChatHandler) List(c *gin.Context) { + var data struct { + Title string `json:"title"` + UserId uint `json:"user_id"` + Model string `json:"model"` + CreateAt []string `json:"created_time"` + Page int `json:"page"` + PageSize int `json:"page_size"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + session := h.DB.Session(&gorm.Session{}) + if data.Title != "" { + session = session.Where("title LIKE ?", "%"+data.Title+"%") + } + if data.UserId > 0 { + session = session.Where("user_id = ?", data.UserId) + } + if data.Model != "" { + session = session.Where("model = ?", data.Model) + } + if len(data.CreateAt) == 2 { + start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00") + end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00") + session = session.Where("created_at >= ? AND created_at <= ?", start, end) + } + + var total int64 + session.Model(&model.ChatItem{}).Count(&total) + var items []model.ChatItem + var list = make([]chatItemVo, 0) + offset := (data.Page - 1) * data.PageSize + res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) + if res.Error == nil { + userIds := make([]uint, 0) + chatIds := make([]string, 0) + roleIds := make([]uint, 0) + for _, item := range items { + userIds = append(userIds, item.UserId) + chatIds = append(chatIds, item.ChatId) + roleIds = append(roleIds, item.RoleId) + } + var messages []model.ChatMessage + var users []model.User + var roles []model.ChatRole + h.DB.Where("chat_id IN ?", chatIds).Find(&messages) + h.DB.Where("id IN ?", userIds).Find(&users) + h.DB.Where("id IN ?", roleIds).Find(&roles) + + tokenMap := make(map[string]int) + userMap := make(map[uint]string) + msgMap := make(map[string]int) + roleMap := make(map[uint]vo.ChatRole) + for _, msg := range messages { + tokenMap[msg.ChatId] += msg.Tokens + msgMap[msg.ChatId] += 1 + } + for _, user := range users { + userMap[user.Id] = user.Username + } + for _, r := range roles { + var roleVo vo.ChatRole + err := utils.CopyObject(r, &roleVo) + if err != nil { + continue + } + roleMap[r.Id] = roleVo + } + for _, item := range items { + list = append(list, chatItemVo{ + UserId: item.UserId, + Username: userMap[item.UserId], + ChatId: item.ChatId, + Title: item.Title, + Model: item.Model, + Token: tokenMap[item.ChatId], + MsgNum: msgMap[item.ChatId], + Role: roleMap[item.RoleId], + CreatedAt: item.CreatedAt.Unix(), + }) + } + } + resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) +} + +type chatMessageVo struct { + Id uint `json:"id"` + UserId uint `json:"user_id"` + Username string `json:"username"` + Content string `json:"content"` + Type string `json:"type"` + Model string `json:"model"` + Token int `json:"token"` + Icon string `json:"icon"` + CreatedAt int64 `json:"created_at"` +} + +// Messages 读取聊天记录列表 +func (h *ChatHandler) Messages(c *gin.Context) { + var data struct { + UserId uint `json:"user_id"` + Content string `json:"content"` + Model string `json:"model"` + CreateAt []string `json:"created_time"` + Page int `json:"page"` + PageSize int `json:"page_size"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + session := h.DB.Session(&gorm.Session{}) + if data.Content != "" { + session = session.Where("content LIKE ?", "%"+data.Content+"%") + } + if data.UserId > 0 { + session = session.Where("user_id = ?", data.UserId) + } + if data.Model != "" { + session = session.Where("model = ?", data.Model) + } + if len(data.CreateAt) == 2 { + start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00") + end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00") + session = session.Where("created_at >= ? AND created_at <= ?", start, end) + } + + var total int64 + session.Model(&model.ChatMessage{}).Count(&total) + var items []model.ChatMessage + var list = make([]chatMessageVo, 0) + offset := (data.Page - 1) * data.PageSize + res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) + if res.Error == nil { + userIds := make([]uint, 0) + for _, item := range items { + userIds = append(userIds, item.UserId) + } + var users []model.User + h.DB.Where("id IN ?", userIds).Find(&users) + userMap := make(map[uint]string) + for _, user := range users { + userMap[user.Id] = user.Username + } + for _, item := range items { + list = append(list, chatMessageVo{ + Id: item.Id, + UserId: item.UserId, + Username: userMap[item.UserId], + Content: item.Content, + Model: item.Model, + Token: item.Tokens, + Icon: item.Icon, + Type: item.Type, + CreatedAt: item.CreatedAt.Unix(), + }) + } + } + resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) +} + +// History 获取聊天历史记录 +func (h *ChatHandler) History(c *gin.Context) { + chatId := c.Query("chat_id") // 会话 ID + var items []model.ChatMessage + var messages = make([]vo.HistoryMessage, 0) + res := h.DB.Where("chat_id = ?", chatId).Find(&items) + if res.Error != nil { + resp.ERROR(c, "No history message") + return + } else { + for _, item := range items { + var v vo.HistoryMessage + err := utils.CopyObject(item, &v) + v.CreatedAt = item.CreatedAt.Unix() + v.UpdatedAt = item.UpdatedAt.Unix() + if err == nil { + messages = append(messages, v) + } + } + } + + resp.SUCCESS(c, messages) +} + +// RemoveChat 删除对话 +func (h *ChatHandler) RemoveChat(c *gin.Context) { + chatId := h.GetTrim(c, "chat_id") + if chatId == "" { + resp.ERROR(c, "请传入 ChatId") + return + } + + tx := h.DB.Begin() + // 删除聊天记录 + res := tx.Unscoped().Debug().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{}) + if res.Error != nil { + resp.ERROR(c, "failed to remove chat message") + return + } + + // 删除对话 + res = tx.Unscoped().Where("chat_id = ?", chatId).Delete(model.ChatItem{}) + if res.Error != nil { + tx.Rollback() // 回滚 + resp.ERROR(c, "failed to remove chat") + return + } + + tx.Commit() + resp.SUCCESS(c) +} + +// RemoveMessage 删除聊天记录 +func (h *ChatHandler) RemoveMessage(c *gin.Context) { + id := h.GetInt(c, "id", 0) + tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{}) + if tx.Error != nil { + logger.Error("error with update database:", tx.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + resp.SUCCESS(c) +} diff --git a/handler/admin/chat_model_handler.go b/handler/admin/chat_model_handler.go new file mode 100644 index 0000000..74f187c --- /dev/null +++ b/handler/admin/chat_model_handler.go @@ -0,0 +1,192 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type ChatModelHandler struct { + handler.BaseHandler +} + +func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler { + return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +} + +func (h *ChatModelHandler) Save(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Name string `json:"name"` + Value string `json:"value"` + Enabled bool `json:"enabled"` + SortNum int `json:"sort_num"` + Open bool `json:"open"` + Platform string `json:"platform"` + Power int `json:"power"` + MaxTokens int `json:"max_tokens"` // 最大响应长度 + MaxContext int `json:"max_context"` // 最大上下文长度 + Temperature float32 `json:"temperature"` // 模型温度 + KeyId int `json:"key_id,omitempty"` + CreatedAt int64 `json:"created_at"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + item := model.ChatModel{ + Platform: data.Platform, + Name: data.Name, + Value: data.Value, + Enabled: data.Enabled, + SortNum: data.SortNum, + Open: data.Open, + MaxTokens: data.MaxTokens, + MaxContext: data.MaxContext, + Temperature: data.Temperature, + KeyId: data.KeyId, + Power: data.Power} + var res *gorm.DB + if data.Id > 0 { + item.Id = data.Id + res = h.DB.Select("*").Omit("created_at").Updates(&item) + } else { + res = h.DB.Create(&item) + } + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + + var itemVo vo.ChatModel + err := utils.CopyObject(item, &itemVo) + if err != nil { + resp.ERROR(c, "数据拷贝失败!") + return + } + itemVo.Id = item.Id + itemVo.CreatedAt = item.CreatedAt.Unix() + resp.SUCCESS(c, itemVo) +} + +// List 模型列表 +func (h *ChatModelHandler) List(c *gin.Context) { + session := h.DB.Session(&gorm.Session{}) + enable := h.GetBool(c, "enable") + platform := h.GetTrim(c, "platform") + if enable { + session = session.Where("enabled", enable) + } + if platform != "" { + session = session.Where("platform", platform) + } + var items []model.ChatModel + var cms = make([]vo.ChatModel, 0) + res := session.Order("sort_num ASC").Find(&items) + if res.Error != nil { + resp.SUCCESS(c, cms) + return + } + + // initialize key name + keyIds := make([]int, 0) + for _, v := range items { + keyIds = append(keyIds, v.KeyId) + } + var keys []model.ApiKey + keyMap := make(map[uint]string) + h.DB.Where("id IN ?", keyIds).Find(&keys) + for _, v := range keys { + keyMap[v.Id] = v.Name + } + for _, item := range items { + var cm vo.ChatModel + err := utils.CopyObject(item, &cm) + if err == nil { + cm.Id = item.Id + cm.CreatedAt = item.CreatedAt.Unix() + cm.UpdatedAt = item.UpdatedAt.Unix() + cm.KeyName = keyMap[uint(item.KeyId)] + cms = append(cms, cm) + } else { + logger.Error(err) + } + } + resp.SUCCESS(c, cms) +} + +func (h *ChatModelHandler) Set(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Filed string `json:"filed"` + Value interface{} `json:"value"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + resp.SUCCESS(c) +} + +func (h *ChatModelHandler) Sort(c *gin.Context) { + var data struct { + Ids []uint `json:"ids"` + Sorts []int `json:"sorts"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + for index, id := range data.Ids { + res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + } + + resp.SUCCESS(c) +} + +func (h *ChatModelHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + if id <= 0 { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{}) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + resp.SUCCESS(c) +} diff --git a/handler/admin/chat_role_handler.go b/handler/admin/chat_role_handler.go new file mode 100644 index 0000000..3e69ae3 --- /dev/null +++ b/handler/admin/chat_role_handler.go @@ -0,0 +1,163 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "time" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type ChatRoleHandler struct { + handler.BaseHandler +} + +func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler { + return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +} + +// Save 创建或者更新某个角色 +func (h *ChatRoleHandler) Save(c *gin.Context) { + var data vo.ChatRole + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + var role model.ChatRole + err := utils.CopyObject(data, &role) + if err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + role.Id = data.Id + if data.CreatedAt > 0 { + role.CreatedAt = time.Unix(data.CreatedAt, 0) + } + res := h.DB.Save(&role) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + // 填充 ID 数据 + data.Id = role.Id + data.CreatedAt = role.CreatedAt.Unix() + resp.SUCCESS(c, data) +} + +func (h *ChatRoleHandler) List(c *gin.Context) { + var items []model.ChatRole + var roles = make([]vo.ChatRole, 0) + res := h.DB.Order("sort_num ASC").Find(&items) + if res.Error != nil { + resp.ERROR(c, "No data found") + return + } + + // initialize model mane for role + modelIds := make([]int, 0) + for _, v := range items { + if v.ModelId > 0 { + modelIds = append(modelIds, v.ModelId) + } + } + + modelNameMap := make(map[int]string) + if len(modelIds) > 0 { + var models []model.ChatModel + tx := h.DB.Where("id IN ?", modelIds).Find(&models) + if tx.Error == nil { + for _, m := range models { + modelNameMap[int(m.Id)] = m.Name + } + } + } + + for _, v := range items { + var role vo.ChatRole + err := utils.CopyObject(v, &role) + if err == nil { + role.Id = v.Id + role.CreatedAt = v.CreatedAt.Unix() + role.UpdatedAt = v.UpdatedAt.Unix() + role.ModelName = modelNameMap[role.ModelId] + roles = append(roles, role) + } + } + + resp.SUCCESS(c, roles) +} + +// Sort 更新角色排序 +func (h *ChatRoleHandler) Sort(c *gin.Context) { + var data struct { + Ids []uint `json:"ids"` + Sorts []int `json:"sorts"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + for index, id := range data.Ids { + res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index]) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + } + + resp.SUCCESS(c) +} + +func (h *ChatRoleHandler) Set(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Filed string `json:"filed"` + Value interface{} `json:"value"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + resp.SUCCESS(c) +} + +func (h *ChatRoleHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + + if id <= 0 { + resp.ERROR(c, types.InvalidArgs) + return + } + res := h.DB.Where("id", id).Delete(&model.ChatRole{}) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "删除失败!") + return + } + resp.SUCCESS(c) +} diff --git a/handler/admin/config_handler.go b/handler/admin/config_handler.go new file mode 100644 index 0000000..584b026 --- /dev/null +++ b/handler/admin/config_handler.go @@ -0,0 +1,197 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/service" + "geekai/service/mj" + "geekai/service/sd" + "geekai/store" + "geekai/store/model" + "geekai/utils" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "github.com/shirou/gopsutil/host" + "gorm.io/gorm" +) + +type ConfigHandler struct { + handler.BaseHandler + levelDB *store.LevelDB + licenseService *service.LicenseService + mjServicePool *mj.ServicePool + sdServicePool *sd.ServicePool +} + +func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, licenseService *service.LicenseService, mjPool *mj.ServicePool, sdPool *sd.ServicePool) *ConfigHandler { + return &ConfigHandler{ + BaseHandler: handler.BaseHandler{App: app, DB: db}, + levelDB: levelDB, + mjServicePool: mjPool, + sdServicePool: sdPool, + licenseService: licenseService, + } +} + +func (h *ConfigHandler) Update(c *gin.Context) { + var data struct { + Key string `json:"key"` + Config struct { + types.SystemConfig + Content string `json:"content,omitempty"` + Updated bool `json:"updated,omitempty"` + } `json:"config"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + value := utils.JsonEncode(&data.Config) + config := model.Config{Key: data.Key, Config: value} + res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key}) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + if config.Id > 0 { + config.Config = value + res := h.DB.Updates(&config) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + // update config cache for AppServer + var cfg model.Config + h.DB.Where("marker", data.Key).First(&cfg) + var err error + if data.Key == "system" { + err = utils.JsonDecode(cfg.Config, &h.App.SysConfig) + } + if err != nil { + resp.ERROR(c, "Failed to update config cache: "+err.Error()) + return + } + logger.Infof("Update AppServer's config successfully: %v", config.Config) + } + + resp.SUCCESS(c, config) +} + +// Get 获取指定的系统配置 +func (h *ConfigHandler) Get(c *gin.Context) { + key := c.Query("key") + var config model.Config + res := h.DB.Where("marker", key).First(&config) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + var value map[string]interface{} + err := utils.JsonDecode(config.Config, &value) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, value) +} + +// Active 激活系统 +func (h *ConfigHandler) Active(c *gin.Context) { + var data struct { + License string `json:"license"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + info, err := host.Info() + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + err = h.licenseService.ActiveLicense(data.License, info.HostID) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, info.HostID) +} + +// GetLicense 获取 License 信息 +func (h *ConfigHandler) GetLicense(c *gin.Context) { + license := h.licenseService.GetLicense() + resp.SUCCESS(c, license) +} + +// GetAppConfig 获取内置配置 +func (h *ConfigHandler) GetAppConfig(c *gin.Context) { + resp.SUCCESS(c, gin.H{ + "mj_plus": h.App.Config.MjPlusConfigs, + "mj_proxy": h.App.Config.MjProxyConfigs, + "sd": h.App.Config.SdConfigs, + "platforms": Platforms, + }) +} + +// SaveDrawingConfig 保存AI绘画配置 +func (h *ConfigHandler) SaveDrawingConfig(c *gin.Context) { + var data struct { + Sd []types.StableDiffusionConfig `json:"sd"` + MjPlus []types.MjPlusConfig `json:"mj_plus"` + MjProxy []types.MjProxyConfig `json:"mj_proxy"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + changed := false + if configChanged(data.Sd, h.App.Config.SdConfigs) { + logger.Debugf("SD 配置变动了") + h.App.Config.SdConfigs = data.Sd + h.sdServicePool.InitServices(data.Sd) + changed = true + } + + if configChanged(data.MjPlus, h.App.Config.MjPlusConfigs) || configChanged(data.MjProxy, h.App.Config.MjProxyConfigs) { + logger.Debugf("MidJourney 配置变动了") + h.App.Config.MjPlusConfigs = data.MjPlus + h.App.Config.MjProxyConfigs = data.MjProxy + h.mjServicePool.InitServices(data.MjPlus, data.MjProxy) + changed = true + } + + if changed { + err := core.SaveConfig(h.App.Config) + if err != nil { + resp.ERROR(c, "更新配置文档失败!") + return + } + } + + resp.SUCCESS(c) + +} + +func configChanged(c1 interface{}, c2 interface{}) bool { + encode1 := utils.JsonEncode(c1) + encode2 := utils.JsonEncode(c2) + return utils.Md5(encode1) != utils.Md5(encode2) +} diff --git a/handler/admin/dashboard_handler.go b/handler/admin/dashboard_handler.go new file mode 100644 index 0000000..536cd49 --- /dev/null +++ b/handler/admin/dashboard_handler.go @@ -0,0 +1,124 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/store/model" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "github.com/shopspring/decimal" + "gorm.io/gorm" + "time" +) + +type DashboardHandler struct { + handler.BaseHandler +} + +func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler { + return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +} + +type statsVo struct { + Users int64 `json:"users"` + Chats int64 `json:"chats"` + Tokens int `json:"tokens"` + Income float64 `json:"income"` + Chart map[string]map[string]float64 `json:"chart"` +} + +func (h *DashboardHandler) Stats(c *gin.Context) { + stats := statsVo{} + // new users statistic + var userCount int64 + now := time.Now() + zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) + res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount) + if res.Error == nil { + stats.Users = userCount + } + + // new chats statistic + var chatCount int64 + res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount) + if res.Error == nil { + stats.Chats = chatCount + } + + // tokens took stats + var historyMessages []model.ChatMessage + res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages) + for _, item := range historyMessages { + stats.Tokens += item.Tokens + } + + // 众筹收入 + var rewards []model.Reward + res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards) + for _, item := range rewards { + stats.Income += item.Amount + } + + // 订单收入 + var orders []model.Order + res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders) + for _, item := range orders { + stats.Income += item.Amount + } + + // 统计7天的订单的图表 + startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02") + var statsChart = make(map[string]map[string]float64) + //// 初始化 + var userStatistic, historyMessagesStatistic, incomeStatistic = make(map[string]float64), make(map[string]float64), make(map[string]float64) + for i := 0; i < 7; i++ { + var initTime = time.Date(now.Year(), now.Month(), now.Day()-i, 0, 0, 0, 0, now.Location()).Format("2006-01-02") + userStatistic[initTime] = float64(0) + historyMessagesStatistic[initTime] = float64(0) + incomeStatistic[initTime] = float64(0) + } + + // 统计用户7天增加的曲线 + var users []model.User + res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users) + if res.Error == nil { + for _, item := range users { + userStatistic[item.CreatedAt.Format("2006-01-02")] += 1 + } + } + + // 统计7天Token 消耗 + res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages) + for _, item := range historyMessages { + historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens) + } + + // 浮点数相加? + // 统计最近7天的众筹 + res = h.DB.Where("created_at > ?", startDate).Find(&rewards) + for _, item := range rewards { + incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64() + } + + // 统计最近7天的订单 + res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders) + for _, item := range orders { + incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64() + } + + statsChart["users"] = userStatistic + statsChart["historyMessage"] = historyMessagesStatistic + statsChart["orders"] = incomeStatistic + + stats.Chart = statsChart + + resp.SUCCESS(c, stats) +} diff --git a/handler/admin/function_handler.go b/handler/admin/function_handler.go new file mode 100644 index 0000000..c9e8005 --- /dev/null +++ b/handler/admin/function_handler.go @@ -0,0 +1,130 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + + "github.com/golang-jwt/jwt/v5" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type FunctionHandler struct { + handler.BaseHandler +} + +func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler { + return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +} + +func (h *FunctionHandler) Save(c *gin.Context) { + var data vo.Function + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + var f = model.Function{ + Id: data.Id, + Name: data.Name, + Label: data.Label, + Description: data.Description, + Parameters: utils.JsonEncode(data.Parameters), + Action: data.Action, + Token: data.Token, + Enabled: data.Enabled, + } + + res := h.DB.Save(&f) + if res.Error != nil { + resp.ERROR(c, "error with save data:"+res.Error.Error()) + return + } + data.Id = f.Id + resp.SUCCESS(c, data) +} + +func (h *FunctionHandler) Set(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Filed string `json:"filed"` + Value interface{} `json:"value"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + resp.SUCCESS(c) +} + +func (h *FunctionHandler) List(c *gin.Context) { + var items []model.Function + res := h.DB.Find(&items) + if res.Error != nil { + resp.ERROR(c, "No data found") + return + } + + functions := make([]vo.Function, 0) + for _, v := range items { + var f vo.Function + err := utils.CopyObject(v, &f) + if err != nil { + continue + } + functions = append(functions, f) + } + resp.SUCCESS(c, functions) +} + +func (h *FunctionHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + + if id > 0 { + res := h.DB.Delete(&model.Function{Id: uint(id)}) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + } + resp.SUCCESS(c) +} + +// GenToken generate function api access token +func (h *FunctionHandler) GenToken(c *gin.Context) { + // 创建 token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": 0, + "expired": 0, + }) + tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey)) + if err != nil { + logger.Error("error with generate token", err) + resp.ERROR(c) + return + } + + resp.SUCCESS(c, tokenString) +} diff --git a/handler/admin/menu_handler.go b/handler/admin/menu_handler.go new file mode 100644 index 0000000..001c601 --- /dev/null +++ b/handler/admin/menu_handler.go @@ -0,0 +1,132 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type MenuHandler struct { + handler.BaseHandler +} + +func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler { + return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +} + +func (h *MenuHandler) Save(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Name string `json:"name"` + Icon string `json:"icon"` + URL string `json:"url"` + SortNum int `json:"sort_num"` + Enabled bool `json:"enabled"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Save(&model.Menu{ + Id: data.Id, + Name: data.Name, + Icon: data.Icon, + URL: data.URL, + SortNum: data.SortNum, + Enabled: data.Enabled, + }) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + resp.SUCCESS(c) +} + +// List 数据列表 +func (h *MenuHandler) List(c *gin.Context) { + var items []model.Menu + var list = make([]vo.Menu, 0) + res := h.DB.Order("sort_num ASC").Find(&items) + if res.Error == nil { + for _, item := range items { + var product vo.Menu + err := utils.CopyObject(item, &product) + if err == nil { + list = append(list, product) + } + } + } + resp.SUCCESS(c, list) +} + +func (h *MenuHandler) Enable(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Enabled bool `json:"enabled"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + resp.SUCCESS(c) +} + +func (h *MenuHandler) Sort(c *gin.Context) { + var data struct { + Ids []uint `json:"ids"` + Sorts []int `json:"sorts"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + for index, id := range data.Ids { + res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index]) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + } + + resp.SUCCESS(c) +} + +func (h *MenuHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + + if id > 0 { + res := h.DB.Where("id", id).Delete(&model.Menu{}) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + } + resp.SUCCESS(c) +} diff --git a/handler/admin/order_handler.go b/handler/admin/order_handler.go new file mode 100644 index 0000000..ab6752b --- /dev/null +++ b/handler/admin/order_handler.go @@ -0,0 +1,103 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type OrderHandler struct { + handler.BaseHandler +} + +func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler { + return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +} + +func (h *OrderHandler) List(c *gin.Context) { + var data struct { + OrderNo string `json:"order_no"` + Status int `json:"status"` + PayTime []string `json:"pay_time"` + Page int `json:"page"` + PageSize int `json:"page_size"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + session := h.DB.Session(&gorm.Session{}) + if data.OrderNo != "" { + session = session.Where("order_no", data.OrderNo) + } + if len(data.PayTime) == 2 { + start := utils.Str2stamp(data.PayTime[0] + " 00:00:00") + end := utils.Str2stamp(data.PayTime[1] + " 00:00:00") + session = session.Where("pay_time >= ? AND pay_time <= ?", start, end) + } + if data.Status >= 0 { + session = session.Where("status", data.Status) + } + var total int64 + session.Model(&model.Order{}).Count(&total) + var items []model.Order + var list = make([]vo.Order, 0) + offset := (data.Page - 1) * data.PageSize + res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) + if res.Error == nil { + for _, item := range items { + var order vo.Order + err := utils.CopyObject(item, &order) + if err == nil { + order.Id = item.Id + order.CreatedAt = item.CreatedAt.Unix() + order.UpdatedAt = item.UpdatedAt.Unix() + list = append(list, order) + } else { + logger.Error(err) + } + } + } + resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) +} + +func (h *OrderHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + + if id > 0 { + var item model.Order + res := h.DB.First(&item, id) + if res.Error != nil { + resp.ERROR(c, "记录不存在!") + return + } + + if item.Status == types.OrderPaidSuccess { + resp.ERROR(c, "已支付订单不允许删除!") + return + } + + res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{}) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + } + resp.SUCCESS(c) +} diff --git a/handler/admin/power_log_handler.go b/handler/admin/power_log_handler.go new file mode 100644 index 0000000..56f63c8 --- /dev/null +++ b/handler/admin/power_log_handler.go @@ -0,0 +1,84 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type PowerLogHandler struct { + handler.BaseHandler +} + +func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler { + return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +} + +func (h *PowerLogHandler) List(c *gin.Context) { + var data struct { + Username string `json:"username"` + Type int `json:"type"` + Model string `json:"model"` + Date []string `json:"date"` + Page int `json:"page"` + PageSize int `json:"page_size"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + session := h.DB.Session(&gorm.Session{}) + if data.Model != "" { + session = session.Where("model", data.Model) + } + if data.Type > 0 { + session = session.Where("type", data.Type) + } + if len(data.Date) == 2 { + start := data.Date[0] + " 00:00:00" + end := data.Date[1] + " 00:00:00" + session = session.Where("created_at >= ? AND created_at <= ?", start, end) + } + + var total int64 + session.Model(&model.PowerLog{}).Count(&total) + var items []model.PowerLog + var list = make([]vo.PowerLog, 0) + offset := (data.Page - 1) * data.PageSize + res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) + if res.Error == nil { + for _, item := range items { + var log vo.PowerLog + err := utils.CopyObject(item, &log) + if err != nil { + continue + } + log.Id = item.Id + log.CreatedAt = item.CreatedAt.Unix() + log.TypeStr = item.Type.String() + list = append(list, log) + } + } + + // 统计消费算力总和 + var totalPower float64 + if len(data.Date) == 2 { + session.Where("mark", 0).Select("SUM(amount) as total_sum").Scan(&totalPower) + } + resp.SUCCESS(c, gin.H{"data": vo.NewPage(total, data.Page, data.PageSize, list), "stat": totalPower}) +} diff --git a/handler/admin/product_handler.go b/handler/admin/product_handler.go new file mode 100644 index 0000000..e2b66c7 --- /dev/null +++ b/handler/admin/product_handler.go @@ -0,0 +1,153 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "gorm.io/gorm" + "time" +) + +type ProductHandler struct { + handler.BaseHandler +} + +func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler { + return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +} + +func (h *ProductHandler) Save(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Name string `json:"name"` + Price float64 `json:"price"` + Discount float64 `json:"discount"` + Enabled bool `json:"enabled"` + Days int `json:"days"` + Power int `json:"power"` + CreatedAt int64 `json:"created_at"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + item := model.Product{ + Name: data.Name, + Price: data.Price, + Discount: data.Discount, + Days: data.Days, + Power: data.Power, + Enabled: data.Enabled} + item.Id = data.Id + if item.Id > 0 { + item.CreatedAt = time.Unix(data.CreatedAt, 0) + } + res := h.DB.Save(&item) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + + var itemVo vo.Product + err := utils.CopyObject(item, &itemVo) + if err != nil { + resp.ERROR(c, "数据拷贝失败!") + return + } + itemVo.Id = item.Id + itemVo.UpdatedAt = item.UpdatedAt.Unix() + resp.SUCCESS(c, itemVo) +} + +// List 数据列表 +func (h *ProductHandler) List(c *gin.Context) { + var items []model.Product + var list = make([]vo.Product, 0) + res := h.DB.Order("sort_num ASC").Find(&items) + if res.Error == nil { + for _, item := range items { + var product vo.Product + err := utils.CopyObject(item, &product) + if err == nil { + product.Id = item.Id + product.CreatedAt = item.CreatedAt.Unix() + product.UpdatedAt = item.UpdatedAt.Unix() + list = append(list, product) + } else { + logger.Error(err) + } + } + } + resp.SUCCESS(c, list) +} + +func (h *ProductHandler) Enable(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Enabled bool `json:"enabled"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + resp.SUCCESS(c) +} + +func (h *ProductHandler) Sort(c *gin.Context) { + var data struct { + Ids []uint `json:"ids"` + Sorts []int `json:"sorts"` + } + + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + for index, id := range data.Ids { + res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index]) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + } + + resp.SUCCESS(c) +} + +func (h *ProductHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + + if id > 0 { + res := h.DB.Where("id", id).Delete(&model.Product{}) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + } + resp.SUCCESS(c) +} diff --git a/handler/admin/reward_handler.go b/handler/admin/reward_handler.go new file mode 100644 index 0000000..2dc2e28 --- /dev/null +++ b/handler/admin/reward_handler.go @@ -0,0 +1,81 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type RewardHandler struct { + handler.BaseHandler +} + +func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler { + return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +} + +func (h *RewardHandler) List(c *gin.Context) { + var items []model.Reward + res := h.DB.Order("id DESC").Find(&items) + var rewards = make([]vo.Reward, 0) + if res.Error == nil { + userIds := make([]uint, 0) + for _, v := range items { + userIds = append(userIds, v.UserId) + } + var users []model.User + h.DB.Where("id IN ?", userIds).Find(&users) + var userMap = make(map[uint]model.User) + for _, u := range users { + userMap[u.Id] = u + } + + for _, v := range items { + var r vo.Reward + err := utils.CopyObject(v, &r) + if err != nil { + continue + } + + r.Id = v.Id + r.Username = userMap[v.UserId].Username + r.CreatedAt = v.CreatedAt.Unix() + r.UpdatedAt = v.UpdatedAt.Unix() + rewards = append(rewards, r) + } + } + + resp.SUCCESS(c, rewards) +} + +func (h *RewardHandler) Remove(c *gin.Context) { + var data struct { + Id uint + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + if data.Id > 0 { + res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{}) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + } + resp.SUCCESS(c) +} diff --git a/handler/admin/types.go b/handler/admin/types.go new file mode 100644 index 0000000..c06139b --- /dev/null +++ b/handler/admin/types.go @@ -0,0 +1,12 @@ +package admin + +import "geekai/core/types" + +var Platforms = []types.Platform{ + types.OpenAI, + types.QWen, + types.XunFei, + types.ChatGLM, + types.Baidu, + types.Azure, +} diff --git a/handler/admin/upload_handler.go b/handler/admin/upload_handler.go new file mode 100644 index 0000000..6467432 --- /dev/null +++ b/handler/admin/upload_handler.go @@ -0,0 +1,52 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/handler" + "geekai/service/oss" + "geekai/store/model" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "gorm.io/gorm" + "time" +) + +type UploadHandler struct { + handler.BaseHandler + uploaderManager *oss.UploaderManager +} + +func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler { + return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager} +} + +func (h *UploadHandler) Upload(c *gin.Context) { + file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file") + if err != nil { + resp.ERROR(c, err.Error()) + return + } + userId := 0 + res := h.DB.Create(&model.File{ + UserId: userId, + Name: file.Name, + ObjKey: file.ObjKey, + URL: file.URL, + Ext: file.Ext, + Size: file.Size, + CreatedAt: time.Time{}, + }) + if res.Error != nil || res.RowsAffected == 0 { + resp.ERROR(c, "error with update database: "+res.Error.Error()) + return + } + + resp.SUCCESS(c, file) +} diff --git a/handler/admin/user_handler.go b/handler/admin/user_handler.go new file mode 100644 index 0000000..95da105 --- /dev/null +++ b/handler/admin/user_handler.go @@ -0,0 +1,251 @@ +package admin + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/service" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "time" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type UserHandler struct { + handler.BaseHandler + licenseService *service.LicenseService +} + +func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler { + return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService} +} + +// List 用户列表 +func (h *UserHandler) List(c *gin.Context) { + page := h.GetInt(c, "page", 1) + pageSize := h.GetInt(c, "page_size", 20) + username := h.GetTrim(c, "username") + + offset := (page - 1) * pageSize + var items []model.User + var users = make([]vo.User, 0) + var total int64 + + session := h.DB.Session(&gorm.Session{}) + if username != "" { + session = session.Where("username LIKE ?", "%"+username+"%") + } + + session.Model(&model.User{}).Count(&total) + res := session.Offset(offset).Limit(pageSize).Find(&items) + if res.Error == nil { + for _, item := range items { + var user vo.User + err := utils.CopyObject(item, &user) + if err == nil { + user.Id = item.Id + user.CreatedAt = item.CreatedAt.Unix() + user.UpdatedAt = item.UpdatedAt.Unix() + users = append(users, user) + } else { + logger.Error(err) + } + } + } + pageVo := vo.NewPage(total, page, pageSize, users) + resp.SUCCESS(c, pageVo) +} + +func (h *UserHandler) Save(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Password string `json:"password"` + Username string `json:"username"` + ChatRoles []string `json:"chat_roles"` + ChatModels []int `json:"chat_models"` + ExpiredTime string `json:"expired_time"` + Status bool `json:"status"` + Vip bool `json:"vip"` + Power int `json:"power"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + // 检测最大注册人数 + var totalUser int64 + h.DB.Model(&model.User{}).Count(&totalUser) + if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum { + resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") + return + } + var user = model.User{} + var res *gorm.DB + var userVo vo.User + if data.Id > 0 { // 更新 + res = h.DB.Where("id", data.Id).First(&user) + if res.Error != nil { + resp.ERROR(c, "user not found") + return + } + var oldPower = user.Power + user.Username = data.Username + user.Status = data.Status + user.Vip = data.Vip + user.Power = data.Power + user.ChatRoles = utils.JsonEncode(data.ChatRoles) + user.ChatModels = utils.JsonEncode(data.ChatModels) + user.ExpiredTime = utils.Str2stamp(data.ExpiredTime) + + res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + // 记录算力日志 + if oldPower != user.Power { + mark := types.PowerAdd + amount := user.Power - oldPower + if oldPower > user.Power { + mark = types.PowerSub + amount = oldPower - user.Power + } + h.DB.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerGift, + Amount: amount, + Balance: user.Power, + Mark: mark, + Model: "管理员", + Remark: fmt.Sprintf("后台管理员强制修改用户算力,修改前:%d,修改后:%d, 管理员ID:%d", oldPower, user.Power, h.GetLoginUserId(c)), + CreatedAt: time.Now(), + }) + } + } else { + salt := utils.RandString(8) + u := model.User{ + Username: data.Username, + Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)), + Password: utils.GenPassword(data.Password, salt), + Avatar: "/images/avatar/user.png", + Salt: salt, + Power: data.Power, + Status: true, + ChatRoles: utils.JsonEncode(data.ChatRoles), + ChatModels: utils.JsonEncode(data.ChatModels), + ExpiredTime: utils.Str2stamp(data.ExpiredTime), + } + res = h.DB.Create(&u) + _ = utils.CopyObject(u, &userVo) + userVo.Id = u.Id + userVo.CreatedAt = u.CreatedAt.Unix() + userVo.UpdatedAt = u.UpdatedAt.Unix() + } + + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败") + return + } + + resp.SUCCESS(c, userVo) +} + +// ResetPass 重置密码 +func (h *UserHandler) ResetPass(c *gin.Context) { + var data struct { + Id uint + Password string + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + var user model.User + res := h.DB.First(&user, data.Id) + if res.Error != nil { + resp.ERROR(c, "No user found") + return + } + + password := utils.GenPassword(data.Password, user.Salt) + user.Password = password + res = h.DB.Updates(&user) + if res.Error != nil { + resp.ERROR(c) + } else { + resp.SUCCESS(c) + } +} + +func (h *UserHandler) Remove(c *gin.Context) { + id := h.GetInt(c, "id", 0) + if id <= 0 { + resp.ERROR(c, types.InvalidArgs) + return + } + // 删除用户 + res := h.DB.Where("id = ?", id).Delete(&model.User{}) + if res.Error != nil { + resp.ERROR(c, "删除失败") + return + } + + // 删除聊天记录 + h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{}) + // 删除聊天历史记录 + h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{}) + // 删除登录日志 + h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{}) + // 删除算力日志 + h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{}) + // 删除众筹日志 + h.DB.Where("user_id = ?", id).Delete(&model.Reward{}) + // 删除绘图任务 + h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{}) + h.DB.Where("user_id = ?", id).Delete(&model.SdJob{}) + // 删除订单 + h.DB.Where("user_id = ?", id).Delete(&model.Order{}) + resp.SUCCESS(c) +} + +func (h *UserHandler) LoginLog(c *gin.Context) { + page := h.GetInt(c, "page", 1) + pageSize := h.GetInt(c, "page_size", 20) + var total int64 + h.DB.Model(&model.UserLoginLog{}).Count(&total) + offset := (page - 1) * pageSize + var items []model.UserLoginLog + res := h.DB.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items) + if res.Error != nil { + resp.ERROR(c, "获取数据失败") + return + } + var logs []vo.UserLoginLog + for _, v := range items { + var log vo.UserLoginLog + err := utils.CopyObject(v, &log) + if err == nil { + log.Id = v.Id + log.CreatedAt = v.CreatedAt.Unix() + logs = append(logs, log) + } + } + + resp.SUCCESS(c, vo.NewPage(total, page, pageSize, logs)) +} diff --git a/handler/base_handler.go b/handler/base_handler.go new file mode 100644 index 0000000..406b9b5 --- /dev/null +++ b/handler/base_handler.go @@ -0,0 +1,94 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + logger2 "geekai/logger" + "geekai/store/model" + "geekai/utils" + "errors" + "fmt" + "gorm.io/gorm" + "strings" + + "github.com/gin-gonic/gin" +) + +var logger = logger2.GetLogger() + +type BaseHandler struct { + App *core.AppServer + DB *gorm.DB +} + +func (h *BaseHandler) GetTrim(c *gin.Context, key string) string { + return strings.TrimSpace(c.Query(key)) +} + +func (h *BaseHandler) PostInt(c *gin.Context, key string, defaultValue int) int { + return utils.IntValue(c.PostForm(key), defaultValue) +} + +func (h *BaseHandler) GetInt(c *gin.Context, key string, defaultValue int) int { + return utils.IntValue(c.Query(key), defaultValue) +} + +func (h *BaseHandler) GetFloat(c *gin.Context, key string) float64 { + return utils.FloatValue(c.Query(key)) +} +func (h *BaseHandler) PostFloat(c *gin.Context, key string) float64 { + return utils.FloatValue(c.PostForm(key)) +} + +func (h *BaseHandler) GetBool(c *gin.Context, key string) bool { + return utils.BoolValue(c.Query(key)) +} +func (h *BaseHandler) PostBool(c *gin.Context, key string) bool { + return utils.BoolValue(c.PostForm(key)) +} +func (h *BaseHandler) GetUserKey(c *gin.Context) string { + userId, ok := c.Get(types.LoginUserID) + if !ok { + return "" + } + return fmt.Sprintf("users/%v", userId) +} + +func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint { + userId, ok := c.Get(types.LoginUserID) + if !ok { + return 0 + } + return uint(utils.IntValue(utils.InterfaceToString(userId), 0)) +} + +func (h *BaseHandler) IsLogin(c *gin.Context) bool { + return h.GetLoginUserId(c) > 0 +} + +func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) { + value, exists := c.Get(types.LoginUserCache) + if exists { + return value.(model.User), nil + } + + userId, ok := c.Get(types.LoginUserID) + if !ok { + return model.User{}, errors.New("user not login") + } + + var user model.User + res := h.DB.First(&user, userId) + // 更新缓存 + if res.Error == nil { + c.Set(types.LoginUserCache, user) + } + return user, res.Error +} diff --git a/handler/captcha_handler.go b/handler/captcha_handler.go new file mode 100644 index 0000000..57852b4 --- /dev/null +++ b/handler/captcha_handler.go @@ -0,0 +1,84 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core/types" + "geekai/service" + "geekai/utils/resp" + "github.com/gin-gonic/gin" +) + +// 今日头条函数实现 + +type CaptchaHandler struct { + service *service.CaptchaService +} + +func NewCaptchaHandler(s *service.CaptchaService) *CaptchaHandler { + return &CaptchaHandler{service: s} +} + +func (h *CaptchaHandler) Get(c *gin.Context) { + data, err := h.service.Get() + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, data) +} + +// Check verify the captcha data +func (h *CaptchaHandler) Check(c *gin.Context) { + var data struct { + Key string `json:"key"` + Dots string `json:"dots"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + if h.service.Check(data) { + resp.SUCCESS(c) + } else { + resp.ERROR(c) + } + +} + +// SlideGet 获取滑动验证图片 +func (h *CaptchaHandler) SlideGet(c *gin.Context) { + data, err := h.service.SlideGet() + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, data) +} + +// SlideCheck 滑动验证结果校验 +func (h *CaptchaHandler) SlideCheck(c *gin.Context) { + var data struct { + Key string `json:"key"` + X int `json:"x"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + if h.service.SlideCheck(data) { + resp.SUCCESS(c) + } else { + resp.ERROR(c) + } + +} diff --git a/handler/chat_model_handler.go b/handler/chat_model_handler.go new file mode 100644 index 0000000..555de7c --- /dev/null +++ b/handler/chat_model_handler.go @@ -0,0 +1,66 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type ChatModelHandler struct { + BaseHandler +} + +func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler { + return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}} +} + +// List 模型列表 +func (h *ChatModelHandler) List(c *gin.Context) { + var items []model.ChatModel + var chatModels = make([]vo.ChatModel, 0) + var res *gorm.DB + // 如果用户没有登录,则加载所有开放模型 + if !h.IsLogin(c) { + res = h.DB.Where("enabled", true).Where("open", true).Order("sort_num ASC").Find(&items) + } else { + user, _ := h.GetLoginUser(c) + var models []int + err := utils.JsonDecode(user.ChatModels, &models) + if err != nil { + resp.ERROR(c, "当前用户没有订阅任何模型") + return + } + // 查询用户有权限访问的模型以及所有开放的模型 + res = h.DB.Where("enabled = ?", true).Where( + h.DB.Where("id IN ?", models).Or("open", true), + ).Order("sort_num ASC").Find(&items) + } + + if res.Error == nil { + for _, item := range items { + var cm vo.ChatModel + err := utils.CopyObject(item, &cm) + if err == nil { + cm.Id = item.Id + cm.CreatedAt = item.CreatedAt.Unix() + cm.UpdatedAt = item.UpdatedAt.Unix() + chatModels = append(chatModels, cm) + } else { + logger.Error(err) + } + } + } + resp.SUCCESS(c, chatModels) +} diff --git a/handler/chat_role_handler.go b/handler/chat_role_handler.go new file mode 100644 index 0000000..707e7f4 --- /dev/null +++ b/handler/chat_role_handler.go @@ -0,0 +1,105 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type ChatRoleHandler struct { + BaseHandler +} + +func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler { + return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}} +} + +// List 获取用户聊天应用列表 +func (h *ChatRoleHandler) List(c *gin.Context) { + all := h.GetBool(c, "all") + userId := h.GetLoginUserId(c) + var roles []model.ChatRole + var roleVos = make([]vo.ChatRole, 0) + res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles) + if res.Error != nil { + resp.SUCCESS(c, roleVos) + return + } + + // 获取所有角色 + if userId == 0 || all { + // 转成 vo + var roleVos = make([]vo.ChatRole, 0) + for _, r := range roles { + var v vo.ChatRole + err := utils.CopyObject(r, &v) + if err == nil { + v.Id = r.Id + roleVos = append(roleVos, v) + } + } + resp.SUCCESS(c, roleVos) + return + } + + var user model.User + h.DB.First(&user, userId) + var roleKeys []string + err := utils.JsonDecode(user.ChatRoles, &roleKeys) + if err != nil { + resp.ERROR(c, "角色解析失败!") + return + } + + for _, r := range roles { + if !utils.ContainsStr(roleKeys, r.Key) { + continue + } + var v vo.ChatRole + err := utils.CopyObject(r, &v) + if err == nil { + v.Id = r.Id + roleVos = append(roleVos, v) + } + } + resp.SUCCESS(c, roleVos) +} + +// UpdateRole 更新用户聊天角色 +func (h *ChatRoleHandler) UpdateRole(c *gin.Context) { + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + var data struct { + Keys []string `json:"keys"` + } + if err = c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys)) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + + resp.SUCCESS(c) +} diff --git a/handler/chatimpl/azure_handler.go b/handler/chatimpl/azure_handler.go new file mode 100644 index 0000000..bd28d72 --- /dev/null +++ b/handler/chatimpl/azure_handler.go @@ -0,0 +1,111 @@ +package chatimpl + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "io" + "strings" + "time" +) + +// 微软 Azure 模型消息发送实现 + +func (h *ChatHandler) sendAzureMessage( + chatCtx []types.Message, + req types.ApiRequest, + userVo vo.User, + ctx context.Context, + session *types.ChatSession, + role model.ChatRole, + prompt string, + ws *types.WsClient) error { + promptCreatedAt := time.Now() // 记录提问时间 + start := time.Now() + var apiKey = model.ApiKey{} + response, err := h.doRequest(ctx, req, session, &apiKey) + logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) + if err != nil { + if strings.Contains(err.Error(), "context canceled") { + return fmt.Errorf("用户取消了请求:%s", prompt) + } else if strings.Contains(err.Error(), "no available key") { + return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") + } + return err + } else { + defer response.Body.Close() + } + + contentType := response.Header.Get("Content-Type") + if strings.Contains(contentType, "text/event-stream") { + replyCreatedAt := time.Now() // 记录回复时间 + // 循环读取 Chunk 消息 + var message = types.Message{} + var contents = make([]string, 0) + scanner := bufio.NewScanner(response.Body) + for scanner.Scan() { + line := scanner.Text() + if !strings.Contains(line, "data:") || len(line) < 30 { + continue + } + + var responseBody = types.ApiResponse{} + err = json.Unmarshal([]byte(line[6:]), &responseBody) + if err != nil { // 数据解析出错 + return errors.New(line) + } + + if len(responseBody.Choices) == 0 { + continue + } + + // 初始化 role + if responseBody.Choices[0].Delta.Role != "" && message.Role == "" { + message.Role = responseBody.Choices[0].Delta.Role + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + continue + } else if responseBody.Choices[0].FinishReason != "" { + break // 输出完成或者输出中断了 + } else { + content := responseBody.Choices[0].Delta.Content + contents = append(contents, utils.InterfaceToString(content)) + utils.ReplyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), + }) + } + } // end for + + if err := scanner.Err(); err != nil { + if strings.Contains(err.Error(), "context canceled") { + logger.Info("用户取消了请求:", prompt) + } else { + logger.Error("信息读取出错:", err) + } + } + + // 消息发送成功 + if len(contents) > 0 { + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) + } + + } else { + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求大模型 API 失败:%s", body) + } + + return nil +} diff --git a/handler/chatimpl/baidu_handler.go b/handler/chatimpl/baidu_handler.go new file mode 100644 index 0000000..783ac3e --- /dev/null +++ b/handler/chatimpl/baidu_handler.go @@ -0,0 +1,185 @@ +package chatimpl + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "io" + "net/http" + "strings" + "time" +) + +type baiduResp struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + SentenceId int `json:"sentence_id"` + IsEnd bool `json:"is_end"` + IsTruncated bool `json:"is_truncated"` + Result string `json:"result"` + NeedClearHistory bool `json:"need_clear_history"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + +// 百度文心一言消息发送实现 + +func (h *ChatHandler) sendBaiduMessage( + chatCtx []types.Message, + req types.ApiRequest, + userVo vo.User, + ctx context.Context, + session *types.ChatSession, + role model.ChatRole, + prompt string, + ws *types.WsClient) error { + promptCreatedAt := time.Now() // 记录提问时间 + start := time.Now() + var apiKey = model.ApiKey{} + response, err := h.doRequest(ctx, req, session, &apiKey) + logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) + if err != nil { + logger.Error(err) + if strings.Contains(err.Error(), "context canceled") { + return fmt.Errorf("用户取消了请求:%s", prompt) + } else if strings.Contains(err.Error(), "no available key") { + return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") + } + return err + } else { + defer response.Body.Close() + } + + contentType := response.Header.Get("Content-Type") + if strings.Contains(contentType, "text/event-stream") { + replyCreatedAt := time.Now() // 记录回复时间 + // 循环读取 Chunk 消息 + var message = types.Message{} + var contents = make([]string, 0) + var content string + scanner := bufio.NewScanner(response.Body) + for scanner.Scan() { + line := scanner.Text() + if len(line) < 5 || strings.HasPrefix(line, "id:") { + continue + } + + if strings.HasPrefix(line, "data:") { + content = line[5:] + } + + // 处理代码换行 + if len(content) == 0 { + content = "\n" + } + + var resp baiduResp + err := utils.JsonDecode(content, &resp) + if err != nil { + logger.Error("error with parse data line: ", err) + utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) + break + } + + if len(contents) == 0 { + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + } + utils.ReplyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(resp.Result), + }) + contents = append(contents, resp.Result) + + if resp.IsTruncated { + utils.ReplyMessage(ws, "AI 输出异常中断") + break + } + + if resp.IsEnd { + break + } + + } // end for + + if err := scanner.Err(); err != nil { + if strings.Contains(err.Error(), "context canceled") { + logger.Info("用户取消了请求:", prompt) + } else { + logger.Error("信息读取出错:", err) + } + } + + // 消息发送成功 + if len(contents) > 0 { + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) + } + } else { + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求大模型 API 失败:%s", body) + } + + return nil +} + +func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) { + ctx := context.Background() + tokenString, err := h.redis.Get(ctx, apiKey).Result() + if err == nil { + return tokenString, nil + } + + expr := time.Hour * 24 * 20 // access_token 有效期 + key := strings.Split(apiKey, "|") + if len(key) != 2 { + return "", fmt.Errorf("invalid api key: %s", apiKey) + } + url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1]) + client := &http.Client{} + req, err := http.NewRequest("POST", url, nil) + if err != nil { + return "", err + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + + res, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("error with send request: %w", err) + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + return "", fmt.Errorf("error with read response: %w", err) + } + var r map[string]interface{} + err = json.Unmarshal(body, &r) + if err != nil { + return "", fmt.Errorf("error with parse response: %w", err) + } + + if r["error"] != nil { + return "", fmt.Errorf("error with api response: %s", r["error_description"]) + } + + tokenString = fmt.Sprintf("%s", r["access_token"]) + h.redis.Set(ctx, apiKey, tokenString, expr) + return tokenString, nil +} diff --git a/handler/chatimpl/chat_handler.go b/handler/chatimpl/chat_handler.go new file mode 100644 index 0000000..4ad6965 --- /dev/null +++ b/handler/chatimpl/chat_handler.go @@ -0,0 +1,701 @@ +package chatimpl + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/handler" + logger2 "geekai/logger" + "geekai/service" + "geekai/service/oss" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "html/template" + "net/http" + "net/url" + "regexp" + "strings" + "time" + "unicode/utf8" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" + "github.com/gorilla/websocket" + "gorm.io/gorm" +) + +var logger = logger2.GetLogger() + +type ChatHandler struct { + handler.BaseHandler + redis *redis.Client + uploadManager *oss.UploaderManager + licenseService *service.LicenseService +} + +func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ChatHandler { + return &ChatHandler{ + BaseHandler: handler.BaseHandler{App: app, DB: db}, + redis: redis, + uploadManager: manager, + licenseService: licenseService, + } +} + +// ChatHandle 处理聊天 WebSocket 请求 +func (h *ChatHandler) ChatHandle(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + return + } + + sessionId := c.Query("session_id") + roleId := h.GetInt(c, "role_id", 0) + chatId := c.Query("chat_id") + modelId := h.GetInt(c, "model_id", 0) + + client := types.NewWsClient(ws) + var chatRole model.ChatRole + res := h.DB.First(&chatRole, roleId) + if res.Error != nil || !chatRole.Enable { + utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") + c.Abort() + return + } + // if the role bind a model_id, use role's bind model_id + if chatRole.ModelId > 0 { + modelId = chatRole.ModelId + } + // get model info + var chatModel model.ChatModel + res = h.DB.First(&chatModel, modelId) + if res.Error != nil || chatModel.Enabled == false { + utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!") + c.Abort() + return + } + + session := h.App.ChatSession.Get(sessionId) + if session == nil { + user, err := h.GetLoginUser(c) + if err != nil { + logger.Info("用户未登录") + c.Abort() + return + } + session = &types.ChatSession{ + SessionId: sessionId, + ClientIP: c.ClientIP(), + Username: user.Username, + UserId: user.Id, + } + h.App.ChatSession.Put(sessionId, session) + } + + // use old chat data override the chat model and role ID + var chat model.ChatItem + res = h.DB.Where("chat_id = ?", chatId).First(&chat) + if res.Error == nil { + chatModel.Id = chat.ModelId + roleId = int(chat.RoleId) + } + + session.ChatId = chatId + session.Model = types.ChatModel{ + Id: chatModel.Id, + Name: chatModel.Name, + Value: chatModel.Value, + Power: chatModel.Power, + MaxTokens: chatModel.MaxTokens, + MaxContext: chatModel.MaxContext, + Temperature: chatModel.Temperature, + KeyId: chatModel.KeyId, + Platform: chatModel.Platform} + logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username) + + // 保存会话连接 + h.App.ChatClients.Put(sessionId, client) + go func() { + for { + _, msg, err := client.Receive() + if err != nil { + logger.Debugf("close connection: %s", client.Conn.RemoteAddr()) + client.Close() + h.App.ChatClients.Delete(sessionId) + h.App.ChatSession.Delete(sessionId) + cancelFunc := h.App.ReqCancelFunc.Get(sessionId) + if cancelFunc != nil { + cancelFunc() + h.App.ReqCancelFunc.Delete(sessionId) + } + return + } + + var message types.WsMessage + err = utils.JsonDecode(string(msg), &message) + if err != nil { + continue + } + + // 心跳消息 + if message.Type == "heartbeat" { + logger.Debug("收到 Chat 心跳消息:", message.Content) + continue + } + + logger.Info("Receive a message: ", message.Content) + + ctx, cancel := context.WithCancel(context.Background()) + h.App.ReqCancelFunc.Put(sessionId, cancel) + // 回复消息 + err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client) + if err != nil { + logger.Error(err) + utils.ReplyMessage(client, err.Error()) + } else { + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) + logger.Infof("回答完毕: %v", message.Content) + } + + } + }() +} + +func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error { + if !h.App.Debug { + defer func() { + if r := recover(); r != nil { + logger.Error("Recover message from error: ", r) + } + }() + } + + var user model.User + res := h.DB.Model(&model.User{}).First(&user, session.UserId) + if res.Error != nil { + return errors.New("未授权用户,您正在进行非法操作!") + } + var userVo vo.User + err := utils.CopyObject(user, &userVo) + userVo.Id = user.Id + if err != nil { + return errors.New("User 对象转换失败," + err.Error()) + } + + if userVo.Status == false { + return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!") + } + + if userVo.Power < session.Model.Power { + return fmt.Errorf("您当前剩余算力 %d 已不足以支付当前模型的单次对话需要消耗的算力 %d,[立即购买](/member)。", userVo.Power, session.Model.Power) + } + + if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() { + return errors.New("您的账号已经过期,请联系管理员!") + } + + // 检查 prompt 长度是否超过了当前模型允许的最大上下文长度 + promptTokens, err := utils.CalcTokens(prompt, session.Model.Value) + if promptTokens > session.Model.MaxContext { + + return errors.New("对话内容超出了当前模型允许的最大上下文长度!") + } + + var req = types.ApiRequest{ + Model: session.Model.Value, + Stream: true, + } + switch session.Model.Platform { + case types.Azure.Value, types.ChatGLM.Value, types.Baidu.Value, types.XunFei.Value: + req.Temperature = session.Model.Temperature + req.MaxTokens = session.Model.MaxTokens + break + case types.OpenAI.Value: + req.Temperature = session.Model.Temperature + req.MaxTokens = session.Model.MaxTokens + // OpenAI 支持函数功能 + var items []model.Function + res := h.DB.Where("enabled", true).Find(&items) + if res.Error != nil { + break + } + + var tools = make([]types.Tool, 0) + for _, v := range items { + var parameters map[string]interface{} + err = utils.JsonDecode(v.Parameters, ¶meters) + if err != nil { + continue + } + tool := types.Tool{ + Type: "function", + Function: types.Function{ + Name: v.Name, + Description: v.Description, + Parameters: parameters, + }, + } + if v, ok := parameters["required"]; v == nil || !ok { + tool.Function.Parameters["required"] = []string{} + } + tools = append(tools, tool) + } + + if len(tools) > 0 { + req.Tools = tools + req.ToolChoice = "auto" + } + case types.QWen.Value: + req.Parameters = map[string]interface{}{ + "max_tokens": session.Model.MaxTokens, + "temperature": session.Model.Temperature, + } + break + + default: + return fmt.Errorf("不支持的平台:%s", session.Model.Platform) + } + + // 加载聊天上下文 + chatCtx := make([]types.Message, 0) + messages := make([]types.Message, 0) + if h.App.SysConfig.EnableContext { + if h.App.ChatContexts.Has(session.ChatId) { + messages = h.App.ChatContexts.Get(session.ChatId) + } else { + _ = utils.JsonDecode(role.Context, &messages) + if h.App.SysConfig.ContextDeep > 0 { + var historyMessages []model.ChatMessage + res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages) + if res.Error == nil { + for i := len(historyMessages) - 1; i >= 0; i-- { + msg := historyMessages[i] + ms := types.Message{Role: "user", Content: msg.Content} + if msg.Type == types.ReplyMsg { + ms.Role = "assistant" + } + chatCtx = append(chatCtx, ms) + } + } + } + } + + // 计算当前请求的 token 总长度,确保不会超出最大上下文长度 + // MaxContextLength = Response + Tool + Prompt + Context + tokens := req.MaxTokens // 最大响应长度 + tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model) + tokens += tks + promptTokens + + for _, v := range messages { + tks, _ := utils.CalcTokens(v.Content, req.Model) + // 上下文 token 超出了模型的最大上下文长度 + if tokens+tks >= session.Model.MaxContext { + break + } + + // 上下文的深度超出了模型的最大上下文深度 + if len(chatCtx) >= h.App.SysConfig.ContextDeep { + break + } + + tokens += tks + chatCtx = append(chatCtx, v) + } + + logger.Debugf("聊天上下文:%+v", chatCtx) + } + reqMgs := make([]interface{}, 0) + for _, m := range chatCtx { + reqMgs = append(reqMgs, m) + } + + if session.Model.Platform == types.QWen.Value { + req.Input = make(map[string]interface{}) + reqMgs = append(reqMgs, types.Message{ + Role: "user", + Content: prompt, + }) + req.Input["messages"] = reqMgs + } else if session.Model.Platform == types.OpenAI.Value { // extract image for gpt-vision model + imgURLs := utils.ExtractImgURL(prompt) + logger.Debugf("detected IMG: %+v", imgURLs) + var content interface{} + if len(imgURLs) > 0 { + data := make([]interface{}, 0) + text := prompt + for _, v := range imgURLs { + text = strings.Replace(text, v, "", 1) + data = append(data, gin.H{ + "type": "image_url", + "image_url": gin.H{ + "url": v, + }, + }) + } + data = append(data, gin.H{ + "type": "text", + "text": text, + }) + content = data + } else { + content = prompt + } + req.Messages = append(reqMgs, map[string]interface{}{ + "role": "user", + "content": content, + }) + } else { + req.Messages = append(reqMgs, map[string]interface{}{ + "role": "user", + "content": prompt, + }) + } + + logger.Debugf("%+v", req.Messages) + + switch session.Model.Platform { + case types.Azure.Value: + return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) + case types.OpenAI.Value: + return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) + case types.ChatGLM.Value: + return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) + case types.Baidu.Value: + return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) + case types.XunFei.Value: + return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) + case types.QWen.Value: + return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) + } + + return nil +} + +// Tokens 统计 token 数量 +func (h *ChatHandler) Tokens(c *gin.Context) { + var data struct { + Text string `json:"text"` + Model string `json:"model"` + ChatId string `json:"chat_id"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + // 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文) + if data.Text == "" && data.ChatId != "" { + var item model.ChatMessage + userId, _ := c.Get(types.LoginUserID) + res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + resp.SUCCESS(c, item.Tokens) + return + } + + tokens, err := utils.CalcTokens(data.Text, data.Model) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, tokens) +} + +func getTotalTokens(req types.ApiRequest) int { + encode := utils.JsonEncode(req.Messages) + var items []map[string]interface{} + err := utils.JsonDecode(encode, &items) + if err != nil { + return 0 + } + tokens := 0 + for _, item := range items { + content, ok := item["content"] + if ok && !utils.IsEmptyValue(content) { + t, err := utils.CalcTokens(utils.InterfaceToString(content), req.Model) + if err == nil { + tokens += t + } + } + } + return tokens +} + +// StopGenerate 停止生成 +func (h *ChatHandler) StopGenerate(c *gin.Context) { + sessionId := c.Query("session_id") + if h.App.ReqCancelFunc.Has(sessionId) { + h.App.ReqCancelFunc.Get(sessionId)() + h.App.ReqCancelFunc.Delete(sessionId) + } + resp.SUCCESS(c, types.OkMsg) +} + +// 发送请求到 OpenAI 服务器 +// useOwnApiKey: 是否使用了用户自己的 API KEY +func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) { + // if the chat model bind a KEY, use it directly + if session.Model.KeyId > 0 { + h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey) + } + // use the last unused key + if apiKey.Id == 0 { + h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey) + } + if apiKey.Id == 0 { + return nil, errors.New("no available key, please import key") + } + + // ONLY allow apiURL in blank list + if session.Model.Platform == types.OpenAI.Value { + err := h.licenseService.IsValidApiURL(apiKey.ApiURL) + if err != nil { + return nil, err + } + } + + var apiURL string + switch session.Model.Platform { + case types.Azure.Value: + md := strings.Replace(req.Model, ".", "", 1) + apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1) + break + case types.ChatGLM.Value: + apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) + req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段 + req.Messages = nil + break + case types.Baidu.Value: + apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1) + break + case types.QWen.Value: + apiURL = apiKey.ApiURL + req.Messages = nil + break + default: + apiURL = apiKey.ApiURL + } + // 更新 API KEY 的最后使用时间 + h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) + // 百度文心,需要串接 access_token + if session.Model.Platform == types.Baidu.Value { + token, err := h.getBaiduToken(apiKey.Value) + if err != nil { + return nil, err + } + logger.Info("百度文心 Access_Token:", token) + apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token) + } + + logger.Debugf(utils.JsonEncode(req)) + + // 创建 HttpClient 请求对象 + var client *http.Client + requestBody, err := json.Marshal(req) + if err != nil { + return nil, err + } + request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody)) + if err != nil { + return nil, err + } + + request = request.WithContext(ctx) + request.Header.Set("Content-Type", "application/json") + if len(apiKey.ProxyURL) > 5 { // 使用代理 + proxy, _ := url.Parse(apiKey.ProxyURL) + client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxy), + }, + } + } else { + client = http.DefaultClient + } + logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model) + switch session.Model.Platform { + case types.Azure.Value: + request.Header.Set("api-key", apiKey.Value) + break + case types.ChatGLM.Value: + token, err := h.getChatGLMToken(apiKey.Value) + if err != nil { + return nil, err + } + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + break + case types.Baidu.Value: + request.RequestURI = "" + case types.OpenAI.Value: + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) + break + case types.QWen.Value: + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) + request.Header.Set("X-DashScope-SSE", "enable") + break + } + return client.Do(request) +} + +// 扣减用户算力 +func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) { + power := 1 + if session.Model.Power > 0 { + power = session.Model.Power + } + res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power)) + if res.Error == nil { + // 记录算力消费日志 + var u model.User + h.DB.Where("id", userVo.Id).First(&u) + h.DB.Create(&model.PowerLog{ + UserId: userVo.Id, + Username: userVo.Username, + Type: types.PowerConsume, + Amount: power, + Mark: types.PowerSub, + Balance: u.Power, + Model: session.Model.Value, + Remark: fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens), + CreatedAt: time.Now(), + }) + } + +} + +func (h *ChatHandler) saveChatHistory( + req types.ApiRequest, + prompt string, + contents []string, + message types.Message, + chatCtx []types.Message, + session *types.ChatSession, + role model.ChatRole, + userVo vo.User, + promptCreatedAt time.Time, + replyCreatedAt time.Time) { + if message.Role == "" { + message.Role = "assistant" + } + message.Content = strings.Join(contents, "") + useMsg := types.Message{Role: "user", Content: prompt} + + // 更新上下文消息,如果是调用函数则不需要更新上下文 + if h.App.SysConfig.EnableContext { + chatCtx = append(chatCtx, useMsg) // 提问消息 + chatCtx = append(chatCtx, message) // 回复消息 + h.App.ChatContexts.Put(session.ChatId, chatCtx) + } + + // 追加聊天记录 + // for prompt + promptToken, err := utils.CalcTokens(prompt, req.Model) + if err != nil { + logger.Error(err) + } + historyUserMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.PromptMsg, + Icon: userVo.Avatar, + Content: template.HTMLEscapeString(prompt), + Tokens: promptToken, + UseContext: true, + Model: req.Model, + } + historyUserMsg.CreatedAt = promptCreatedAt + historyUserMsg.UpdatedAt = promptCreatedAt + res := h.DB.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("failed to save prompt history message: ", res.Error) + } + + // for reply + // 计算本次对话消耗的总 token 数量 + replyTokens, _ := utils.CalcTokens(message.Content, req.Model) + totalTokens := replyTokens + getTotalTokens(req) + historyReplyMsg := model.ChatMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.ReplyMsg, + Icon: role.Icon, + Content: message.Content, + Tokens: totalTokens, + UseContext: true, + Model: req.Model, + } + historyReplyMsg.CreatedAt = replyCreatedAt + historyReplyMsg.UpdatedAt = replyCreatedAt + res = h.DB.Create(&historyReplyMsg) + if res.Error != nil { + logger.Error("failed to save reply history message: ", res.Error) + } + + if session.Model.Power > 0 { + // 更新用户算力 + h.subUserPower(userVo, session, promptToken, replyTokens) + + // 保存当前会话 + var chatItem model.ChatItem + res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem) + if res.Error != nil { + chatItem.ChatId = session.ChatId + chatItem.UserId = session.UserId + chatItem.RoleId = role.Id + chatItem.ModelId = session.Model.Id + if utf8.RuneCountInString(prompt) > 30 { + chatItem.Title = string([]rune(prompt)[:30]) + "..." + } else { + chatItem.Title = prompt + } + chatItem.Model = req.Model + h.DB.Create(&chatItem) + } + } +} + +// 将AI回复消息中生成的图片链接下载到本地 +func (h *ChatHandler) extractImgUrl(text string) string { + pattern := `!\[([^\]]*)]\(([^)]+)\)` + re := regexp.MustCompile(pattern) + matches := re.FindAllStringSubmatch(text, -1) + + // 下载图片并替换链接地址 + for _, match := range matches { + imageURL := match[2] + logger.Debug(imageURL) + // 对于相同地址的图片,已经被替换了,就不再重复下载了 + if !strings.Contains(text, imageURL) { + continue + } + + newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false) + if err != nil { + logger.Error("error with download image: ", err) + continue + } + + text = strings.ReplaceAll(text, imageURL, newImgURL) + } + return text +} diff --git a/handler/chatimpl/chat_item_handler.go b/handler/chatimpl/chat_item_handler.go new file mode 100644 index 0000000..3e04bf6 --- /dev/null +++ b/handler/chatimpl/chat_item_handler.go @@ -0,0 +1,213 @@ +package chatimpl + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +// List 获取会话列表 +func (h *ChatHandler) List(c *gin.Context) { + if !h.IsLogin(c) { + resp.SUCCESS(c) + return + } + + userId := h.GetLoginUserId(c) + var items = make([]vo.ChatItem, 0) + var chats []model.ChatItem + res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats) + if res.Error == nil { + var roleIds = make([]uint, 0) + for _, chat := range chats { + roleIds = append(roleIds, chat.RoleId) + } + var roles []model.ChatRole + res = h.DB.Find(&roles, roleIds) + if res.Error == nil { + roleMap := make(map[uint]model.ChatRole) + for _, role := range roles { + roleMap[role.Id] = role + } + + for _, chat := range chats { + var item vo.ChatItem + err := utils.CopyObject(chat, &item) + if err == nil { + item.Id = chat.Id + item.Icon = roleMap[chat.RoleId].Icon + items = append(items, item) + } + } + } + + } + resp.SUCCESS(c, items) +} + +// Update 更新会话标题 +func (h *ChatHandler) Update(c *gin.Context) { + var data struct { + ChatId string `json:"chat_id"` + Title string `json:"title"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + res := h.DB.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title) + if res.Error != nil { + resp.ERROR(c, "Failed to update database") + return + } + + resp.SUCCESS(c, types.OkMsg) +} + +// Clear 清空所有聊天记录 +func (h *ChatHandler) Clear(c *gin.Context) { + // 获取当前登录用户所有的聊天会话 + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + var chats []model.ChatItem + res := h.DB.Where("user_id = ?", user.Id).Find(&chats) + if res.Error != nil { + resp.ERROR(c, "No chats found") + return + } + + var chatIds = make([]string, 0) + for _, chat := range chats { + chatIds = append(chatIds, chat.ChatId) + // 清空会话上下文 + h.App.ChatContexts.Delete(chat.ChatId) + } + err = h.DB.Transaction(func(tx *gorm.DB) error { + res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{}) + if res.Error != nil { + return res.Error + } + + res = h.DB.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{}) + if res.Error != nil { + return res.Error + } + + // TODO: 是否要删除 MidJourney 绘画记录和图片文件? + return nil + }) + + if err != nil { + logger.Errorf("Error with delete chats: %+v", err) + resp.ERROR(c, "Failed to remove chat from database.") + return + } + + resp.SUCCESS(c, types.OkMsg) +} + +// History 获取聊天历史记录 +func (h *ChatHandler) History(c *gin.Context) { + chatId := c.Query("chat_id") // 会话 ID + var items []model.ChatMessage + var messages = make([]vo.HistoryMessage, 0) + res := h.DB.Where("chat_id = ?", chatId).Find(&items) + if res.Error != nil { + resp.ERROR(c, "No history message") + return + } else { + for _, item := range items { + var v vo.HistoryMessage + err := utils.CopyObject(item, &v) + v.CreatedAt = item.CreatedAt.Unix() + v.UpdatedAt = item.UpdatedAt.Unix() + if err == nil { + messages = append(messages, v) + } + } + } + + resp.SUCCESS(c, messages) +} + +// Remove 删除会话 +func (h *ChatHandler) Remove(c *gin.Context) { + chatId := h.GetTrim(c, "chat_id") + if chatId == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + res := h.DB.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{}) + if res.Error != nil { + resp.ERROR(c, "Failed to update database") + return + } + + // 删除当前会话的聊天记录 + res = h.DB.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{}) + if res.Error != nil { + resp.ERROR(c, "Failed to remove chat from database.") + return + } + + // TODO: 是否要删除 MidJourney 绘画记录和图片文件? + + // 清空会话上下文 + h.App.ChatContexts.Delete(chatId) + resp.SUCCESS(c, types.OkMsg) +} + +// Detail 对话详情,用户导出对话 +func (h *ChatHandler) Detail(c *gin.Context) { + chatId := h.GetTrim(c, "chat_id") + if utils.IsEmptyValue(chatId) { + resp.ERROR(c, "Invalid chatId") + return + } + + var chatItem model.ChatItem + res := h.DB.Where("chat_id = ?", chatId).First(&chatItem) + if res.Error != nil { + resp.ERROR(c, "No chat found") + return + } + + // 填充角色名称 + var role model.ChatRole + res = h.DB.Where("id", chatItem.RoleId).First(&role) + if res.Error != nil { + resp.ERROR(c, "Role not found") + return + } + + var chatItemVo vo.ChatItem + err := utils.CopyObject(chatItem, &chatItemVo) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + chatItemVo.RoleName = role.Name + resp.SUCCESS(c, chatItemVo) +} diff --git a/handler/chatimpl/chatglm_handler.go b/handler/chatimpl/chatglm_handler.go new file mode 100644 index 0000000..0192abc --- /dev/null +++ b/handler/chatimpl/chatglm_handler.go @@ -0,0 +1,142 @@ +package chatimpl + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bufio" + "context" + "errors" + "fmt" + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "github.com/golang-jwt/jwt/v5" + "io" + "strings" + "time" +) + +// 清华大学 ChatGML 消息发送实现 + +func (h *ChatHandler) sendChatGLMMessage( + chatCtx []types.Message, + req types.ApiRequest, + userVo vo.User, + ctx context.Context, + session *types.ChatSession, + role model.ChatRole, + prompt string, + ws *types.WsClient) error { + promptCreatedAt := time.Now() // 记录提问时间 + start := time.Now() + var apiKey = model.ApiKey{} + response, err := h.doRequest(ctx, req, session, &apiKey) + logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) + if err != nil { + if strings.Contains(err.Error(), "context canceled") { + return fmt.Errorf("用户取消了请求:%s", prompt) + } else if strings.Contains(err.Error(), "no available key") { + return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") + } + return err + } else { + defer response.Body.Close() + } + + contentType := response.Header.Get("Content-Type") + if strings.Contains(contentType, "text/event-stream") { + replyCreatedAt := time.Now() // 记录回复时间 + // 循环读取 Chunk 消息 + var message = types.Message{} + var contents = make([]string, 0) + var event, content string + scanner := bufio.NewScanner(response.Body) + for scanner.Scan() { + line := scanner.Text() + if len(line) < 5 || strings.HasPrefix(line, "id:") { + continue + } + if strings.HasPrefix(line, "event:") { + event = line[6:] + continue + } + + if strings.HasPrefix(line, "data:") { + content = line[5:] + } + // 处理代码换行 + if len(content) == 0 { + content = "\n" + } + switch event { + case "add": + if len(contents) == 0 { + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + } + utils.ReplyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(content), + }) + contents = append(contents, content) + case "finish": + break + case "error": + utils.ReplyMessage(ws, fmt.Sprintf("**调用 ChatGLM API 出错:%s**", content)) + break + case "interrupted": + utils.ReplyMessage(ws, "**调用 ChatGLM API 出错,当前输出被中断!**") + } + + } // end for + + if err := scanner.Err(); err != nil { + if strings.Contains(err.Error(), "context canceled") { + logger.Info("用户取消了请求:", prompt) + } else { + logger.Error("信息读取出错:", err) + } + } + + // 消息发送成功 + if len(contents) > 0 { + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) + } + } else { + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求大模型 API 失败:%s", body) + } + + return nil +} + +func (h *ChatHandler) getChatGLMToken(apiKey string) (string, error) { + ctx := context.Background() + tokenString, err := h.redis.Get(ctx, apiKey).Result() + if err == nil { + return tokenString, nil + } + + expr := time.Hour * 2 + key := strings.Split(apiKey, ".") + if len(key) != 2 { + return "", fmt.Errorf("invalid api key: %s", apiKey) + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "api_key": key[0], + "timestamp": time.Now().Unix(), + "exp": time.Now().Add(expr).Add(time.Second * 10).Unix(), + }) + token.Header["alg"] = "HS256" + token.Header["sign_type"] = "SIGN" + delete(token.Header, "typ") + // Sign and get the complete encoded token as a string using the secret + tokenString, err = token.SignedString([]byte(key[1])) + h.redis.Set(ctx, apiKey, tokenString, expr) + return tokenString, err +} diff --git a/handler/chatimpl/openai_handler.go b/handler/chatimpl/openai_handler.go new file mode 100644 index 0000000..fb953b7 --- /dev/null +++ b/handler/chatimpl/openai_handler.go @@ -0,0 +1,186 @@ +package chatimpl + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + req2 "github.com/imroc/req/v3" + "io" + "strings" + "time" +) + +// OPenAI 消息发送实现 +func (h *ChatHandler) sendOpenAiMessage( + chatCtx []types.Message, + req types.ApiRequest, + userVo vo.User, + ctx context.Context, + session *types.ChatSession, + role model.ChatRole, + prompt string, + ws *types.WsClient) error { + promptCreatedAt := time.Now() // 记录提问时间 + start := time.Now() + var apiKey = model.ApiKey{} + response, err := h.doRequest(ctx, req, session, &apiKey) + logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) + if err != nil { + if strings.Contains(err.Error(), "context canceled") { + return fmt.Errorf("用户取消了请求:%s", prompt) + } else if strings.Contains(err.Error(), "no available key") { + return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") + } + return err + } else { + defer response.Body.Close() + } + + contentType := response.Header.Get("Content-Type") + if strings.Contains(contentType, "text/event-stream") { + replyCreatedAt := time.Now() // 记录回复时间 + // 循环读取 Chunk 消息 + var message = types.Message{} + var contents = make([]string, 0) + var function model.Function + var toolCall = false + var arguments = make([]string, 0) + scanner := bufio.NewScanner(response.Body) + var isNew = true + for scanner.Scan() { + line := scanner.Text() + if !strings.Contains(line, "data:") || len(line) < 30 { + continue + } + + var responseBody = types.ApiResponse{} + err = json.Unmarshal([]byte(line[6:]), &responseBody) + if err != nil { // 数据解析出错 + return errors.New(line) + } + if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行 + continue + } + + if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 { + utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。") + break + } + + var tool types.ToolCall + if len(responseBody.Choices[0].Delta.ToolCalls) > 0 { + tool = responseBody.Choices[0].Delta.ToolCalls[0] + if toolCall && tool.Function.Name == "" { + arguments = append(arguments, tool.Function.Arguments) + continue + } + } + + // 兼容 Function Call + fun := responseBody.Choices[0].Delta.FunctionCall + if fun.Name != "" { + tool = *new(types.ToolCall) + tool.Function.Name = fun.Name + } else if toolCall { + arguments = append(arguments, fun.Arguments) + continue + } + + if !utils.IsEmptyValue(tool) { + res := h.DB.Where("name = ?", tool.Function.Name).First(&function) + if res.Error == nil { + toolCall = true + callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg}) + contents = append(contents, callMsg) + } + continue + } + + if responseBody.Choices[0].FinishReason == "tool_calls" || + responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕 + break + } + + // output stopped + if responseBody.Choices[0].FinishReason != "" { + break // 输出完成或者输出中断了 + } else { + content := responseBody.Choices[0].Delta.Content + contents = append(contents, utils.InterfaceToString(content)) + if isNew { + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + isNew = false + } + utils.ReplyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), + }) + } + } // end for + + if err := scanner.Err(); err != nil { + if strings.Contains(err.Error(), "context canceled") { + logger.Info("用户取消了请求:", prompt) + } else { + logger.Error("信息读取出错:", err) + } + } + + if toolCall { // 调用函数完成任务 + var params map[string]interface{} + _ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms) + logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params) + params["user_id"] = userVo.Id + var apiRes types.BizVo + r, err := req2.C().R().SetHeader("Content-Type", "application/json"). + SetHeader("Authorization", function.Token). + SetBody(params). + SetSuccessResult(&apiRes).Post(function.Action) + errMsg := "" + if err != nil { + errMsg = err.Error() + } else if r.IsErrorState() { + errMsg = r.Status + } + if errMsg != "" || apiRes.Code != types.Success { + msg := "调用函数工具出错:" + apiRes.Message + errMsg + utils.ReplyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: msg, + }) + contents = append(contents, msg) + } else { + utils.ReplyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: apiRes.Data, + }) + contents = append(contents, utils.InterfaceToString(apiRes.Data)) + } + } + + // 消息发送成功 + if len(contents) > 0 { + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) + } + } else { + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求 OpenAI API 失败:%s", body) + } + + return nil +} diff --git a/handler/chatimpl/qwen_handler.go b/handler/chatimpl/qwen_handler.go new file mode 100644 index 0000000..28bf66b --- /dev/null +++ b/handler/chatimpl/qwen_handler.go @@ -0,0 +1,150 @@ +package chatimpl + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bufio" + "context" + "fmt" + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "github.com/syndtr/goleveldb/leveldb/errors" + "io" + "strings" + "time" +) + +type qWenResp struct { + Output struct { + FinishReason string `json:"finish_reason"` + Text string `json:"text"` + } `json:"output,omitempty"` + Usage struct { + TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage,omitempty"` + RequestID string `json:"request_id"` + + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +// 通义千问消息发送实现 +func (h *ChatHandler) sendQWenMessage( + chatCtx []types.Message, + req types.ApiRequest, + userVo vo.User, + ctx context.Context, + session *types.ChatSession, + role model.ChatRole, + prompt string, + ws *types.WsClient) error { + promptCreatedAt := time.Now() // 记录提问时间 + start := time.Now() + var apiKey = model.ApiKey{} + response, err := h.doRequest(ctx, req, session, &apiKey) + logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) + if err != nil { + if strings.Contains(err.Error(), "context canceled") { + return fmt.Errorf("用户取消了请求:%s", prompt) + } else if strings.Contains(err.Error(), "no available key") { + return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") + } + return err + } else { + defer response.Body.Close() + } + contentType := response.Header.Get("Content-Type") + if strings.Contains(contentType, "text/event-stream") { + replyCreatedAt := time.Now() // 记录回复时间 + // 循环读取 Chunk 消息 + var message = types.Message{} + var contents = make([]string, 0) + scanner := bufio.NewScanner(response.Body) + + var content, lastText, newText string + var outPutStart = false + + for scanner.Scan() { + line := scanner.Text() + if len(line) < 5 || strings.HasPrefix(line, "id:") || + strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") { + continue + } + + if !strings.HasPrefix(line, "data:") { + continue + } + + content = line[5:] + var resp qWenResp + if len(contents) == 0 { // 发送消息头 + if !outPutStart { + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + outPutStart = true + continue + } else { + // 处理代码换行 + content = "\n" + } + } else { + err := utils.JsonDecode(content, &resp) + if err != nil { + logger.Error("error with parse data line: ", content) + utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) + break + } + if resp.Message != "" { + utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message)) + break + } + } + + //通过比较 lastText(上一次的文本)和 currentText(当前的文本), + //提取出新添加的文本部分。然后只将这部分新文本发送到客户端。 + //每次循环结束后,lastText 会更新为当前的完整文本,以便于下一次循环进行比较。 + currentText := resp.Output.Text + if currentText != lastText { + // 提取新增文本 + newText = strings.Replace(currentText, lastText, "", 1) + utils.ReplyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(newText), + }) + lastText = currentText // 更新 lastText + } + contents = append(contents, newText) + + if resp.Output.FinishReason == "stop" { + break + } + + } //end for + + if err := scanner.Err(); err != nil { + if strings.Contains(err.Error(), "context canceled") { + logger.Info("用户取消了请求:", prompt) + } else { + logger.Error("信息读取出错:", err) + } + } + + // 消息发送成功 + if len(contents) > 0 { + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) + } + } else { + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求大模型 API 失败:%s", body) + } + + return nil +} diff --git a/handler/chatimpl/xunfei_handler.go b/handler/chatimpl/xunfei_handler.go new file mode 100644 index 0000000..e4a081f --- /dev/null +++ b/handler/chatimpl/xunfei_handler.go @@ -0,0 +1,255 @@ +package chatimpl + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "github.com/gorilla/websocket" + "gorm.io/gorm" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +type xunFeiResp struct { + Header struct { + Code int `json:"code"` + Message string `json:"message"` + Sid string `json:"sid"` + Status int `json:"status"` + } `json:"header"` + Payload struct { + Choices struct { + Status int `json:"status"` + Seq int `json:"seq"` + Text []struct { + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` + } `json:"text"` + } `json:"choices"` + Usage struct { + Text struct { + QuestionTokens int `json:"question_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"text"` + } `json:"usage"` + } `json:"payload"` +} + +var Model2URL = map[string]string{ + "general": "v1.1", + "generalv2": "v2.1", + "generalv3": "v3.1", + "generalv3.5": "v3.5", +} + +// 科大讯飞消息发送实现 + +func (h *ChatHandler) sendXunFeiMessage( + chatCtx []types.Message, + req types.ApiRequest, + userVo vo.User, + ctx context.Context, + session *types.ChatSession, + role model.ChatRole, + prompt string, + ws *types.WsClient) error { + promptCreatedAt := time.Now() // 记录提问时间 + var apiKey model.ApiKey + var res *gorm.DB + // use the bind key + if session.Model.KeyId > 0 { + res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey) + } + // use the last unused key + if apiKey.Id == 0 { + res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey) + } + if res.Error != nil { + return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") + } + // 更新 API KEY 的最后使用时间 + h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) + + d := websocket.Dialer{ + HandshakeTimeout: 5 * time.Second, + } + key := strings.Split(apiKey.Value, "|") + if len(key) != 3 { + utils.ReplyMessage(ws, "非法的 API KEY!") + return nil + } + + apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1) + logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model) + wsURL, err := assembleAuthUrl(apiURL, key[1], key[2]) + //握手并建立websocket 连接 + conn, resp, err := d.Dial(wsURL, nil) + if err != nil { + logger.Error(readResp(resp) + err.Error()) + utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error()) + return nil + } else if resp.StatusCode != 101 { + utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error()) + return nil + } + + data := buildRequest(key[0], req) + fmt.Printf("%+v", data) + fmt.Println(apiURL) + err = conn.WriteJSON(data) + if err != nil { + utils.ReplyMessage(ws, "发送消息失败:"+err.Error()) + return nil + } + + replyCreatedAt := time.Now() // 记录回复时间 + // 循环读取 Chunk 消息 + var message = types.Message{} + var contents = make([]string, 0) + var content string + for { + _, msg, err := conn.ReadMessage() + if err != nil { + logger.Error("error with read message:", err) + utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err)) + break + } + + // 解析数据 + var result xunFeiResp + err = json.Unmarshal(msg, &result) + if err != nil { + logger.Error("error with parsing JSON:", err) + utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) + return nil + } + + if result.Header.Code != 0 { + utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message)) + return nil + } + + content = result.Payload.Choices.Text[0].Content + // 处理代码换行 + if len(content) == 0 { + content = "\n" + } + contents = append(contents, content) + // 第一个结果 + if result.Payload.Choices.Status == 0 { + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + } + utils.ReplyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(content), + }) + + if result.Payload.Choices.Status == 2 { // 最终结果 + _ = conn.Close() // 关闭连接 + break + } + + select { + case <-ctx.Done(): + utils.ReplyMessage(ws, "**用户取消了生成指令!**") + return nil + default: + continue + } + + } + // 消息发送成功 + if len(contents) > 0 { + h.saveChatHistory(req, prompt, contents, message, chatCtx, session, role, userVo, promptCreatedAt, replyCreatedAt) + } + return nil +} + +// 构建 websocket 请求实体 +func buildRequest(appid string, req types.ApiRequest) map[string]interface{} { + return map[string]interface{}{ + "header": map[string]interface{}{ + "app_id": appid, + }, + "parameter": map[string]interface{}{ + "chat": map[string]interface{}{ + "domain": req.Model, + "temperature": req.Temperature, + "top_k": int64(6), + "max_tokens": int64(req.MaxTokens), + "auditing": "default", + }, + }, + "payload": map[string]interface{}{ + "message": map[string]interface{}{ + "text": req.Messages, + }, + }, + } +} + +// 创建鉴权 URL +func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) { + ul, err := url.Parse(hostURL) + if err != nil { + return "", err + } + + date := time.Now().UTC().Format(time.RFC1123) + signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} + //拼接签名字符串 + signStr := strings.Join(signString, "\n") + sha := hmacWithSha256(signStr, apiSecret) + + authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, + "hmac-sha256", "host date request-line", sha) + //将请求参数使用base64编码 + authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) + v := url.Values{} + v.Add("host", ul.Host) + v.Add("date", date) + v.Add("authorization", authorization) + //将编码后的字符串url encode后添加到url后面 + return hostURL + "?" + v.Encode(), nil +} + +// 使用 sha256 签名 +func hmacWithSha256(data, key string) string { + mac := hmac.New(sha256.New, []byte(key)) + mac.Write([]byte(data)) + encodeData := mac.Sum(nil) + return base64.StdEncoding.EncodeToString(encodeData) +} + +// 读取响应 +func readResp(resp *http.Response) string { + if resp == nil { + return "" + } + b, err := io.ReadAll(resp.Body) + if err != nil { + panic(err) + } + return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b)) +} diff --git a/handler/config_handler.go b/handler/config_handler.go new file mode 100644 index 0000000..30e33b8 --- /dev/null +++ b/handler/config_handler.go @@ -0,0 +1,54 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/service" + "geekai/store/model" + "geekai/utils" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type ConfigHandler struct { + BaseHandler + licenseService *service.LicenseService +} + +func NewConfigHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *ConfigHandler { + return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}, licenseService: licenseService} +} + +// Get 获取指定的系统配置 +func (h *ConfigHandler) Get(c *gin.Context) { + key := c.Query("key") + var config model.Config + res := h.DB.Where("marker", key).First(&config) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + var value map[string]interface{} + err := utils.JsonDecode(config.Config, &value) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, value) +} + +// License 获取 License 配置 +func (h *ConfigHandler) License(c *gin.Context) { + license := h.licenseService.GetLicense() + resp.SUCCESS(c, license.Configs) +} diff --git a/handler/dalle_handler.go b/handler/dalle_handler.go new file mode 100644 index 0000000..07cd032 --- /dev/null +++ b/handler/dalle_handler.go @@ -0,0 +1,262 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/service/dalle" + "geekai/service/oss" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "net/http" + + "github.com/gorilla/websocket" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" + "gorm.io/gorm" +) + +type DallJobHandler struct { + BaseHandler + redis *redis.Client + service *dalle.Service + uploader *oss.UploaderManager +} + +func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler { + return &DallJobHandler{ + service: service, + uploader: manager, + BaseHandler: BaseHandler{ + App: app, + DB: db, + }, + } +} + +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *DallJobHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + c.Abort() + return + } + + userId := h.GetInt(c, "user_id", 0) + if userId == 0 { + logger.Info("Invalid user ID") + c.Abort() + return + } + + client := types.NewWsClient(ws) + h.service.Clients.Put(uint(userId), client) + logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) + go func() { + for { + _, msg, err := client.Receive() + if err != nil { + client.Close() + h.service.Clients.Delete(uint(userId)) + return + } + + var message types.WsMessage + err = utils.JsonDecode(string(msg), &message) + if err != nil { + continue + } + + // 心跳消息 + if message.Type == "heartbeat" { + logger.Debug("收到 DallE 心跳消息:", message.Content) + continue + } + } + }() +} + +func (h *DallJobHandler) preCheck(c *gin.Context) bool { + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return false + } + if user.Power < h.App.SysConfig.DallPower { + resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") + return false + } + + return true + +} + +// Image 创建一个绘画任务 +func (h *DallJobHandler) Image(c *gin.Context) { + if !h.preCheck(c) { + return + } + + var data types.DallTask + if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + + idValue, _ := c.Get(types.LoginUserID) + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + job := model.DallJob{ + UserId: uint(userId), + Prompt: data.Prompt, + Power: h.App.SysConfig.DallPower, + } + res := h.DB.Create(&job) + if res.Error != nil { + resp.ERROR(c, "error with save job: "+res.Error.Error()) + return + } + + h.service.PushTask(types.DallTask{ + JobId: job.Id, + UserId: uint(userId), + Prompt: data.Prompt, + Quality: data.Quality, + Size: data.Size, + Style: data.Style, + Power: job.Power, + }) + + client := h.service.Clients.Get(job.UserId) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + resp.SUCCESS(c) +} + +// ImgWall 照片墙 +func (h *DallJobHandler) ImgWall(c *gin.Context) { + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + err, jobs := h.getData(true, 0, page, pageSize, true) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 SD 任务列表 +func (h *DallJobHandler) JobList(c *gin.Context) { + status := h.GetBool(c, "status") + userId := h.GetLoginUserId(c) + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + publish := h.GetBool(c, "publish") + + err, jobs := h.getData(status, userId, page, pageSize, publish) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取任务列表 +func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) { + + session := h.DB.Session(&gorm.Session{}) + if finish { + session = session.Where("progress = ?", 100).Order("id DESC") + } else { + session = session.Where("progress < ?", 100).Order("id ASC") + } + if userId > 0 { + session = session.Where("user_id = ?", userId) + } + if publish { + session = session.Where("publish", publish) + } + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + session = session.Offset(offset).Limit(pageSize) + } + + var items []model.DallJob + res := session.Find(&items) + if res.Error != nil { + return res.Error, nil + } + + var jobs = make([]vo.DallJob, 0) + for _, item := range items { + var job vo.DallJob + err := utils.CopyObject(item, &job) + if err != nil { + continue + } + jobs = append(jobs, job) + } + + return nil, jobs +} + +// Remove remove task image +func (h *DallJobHandler) Remove(c *gin.Context) { + var data struct { + Id uint `json:"id"` + UserId uint `json:"user_id"` + ImgURL string `json:"img_url"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + // remove job recode + res := h.DB.Delete(&model.DallJob{Id: data.Id}) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + // remove image + err := h.uploader.GetUploadHandler().Delete(data.ImgURL) + if err != nil { + logger.Error("remove image failed: ", err) + } + + resp.SUCCESS(c) +} + +// Publish 发布/取消发布图片到画廊显示 +func (h *DallJobHandler) Publish(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享 + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败") + return + } + + resp.SUCCESS(c) +} diff --git a/handler/function_handler.go b/handler/function_handler.go new file mode 100644 index 0000000..6917efd --- /dev/null +++ b/handler/function_handler.go @@ -0,0 +1,226 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/service/dalle" + "geekai/service/oss" + "geekai/store/model" + "geekai/utils" + "geekai/utils/resp" + "errors" + "fmt" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/imroc/req/v3" + "gorm.io/gorm" +) + +type FunctionHandler struct { + BaseHandler + config types.ApiConfig + uploadManager *oss.UploaderManager + dallService *dalle.Service +} + +func NewFunctionHandler( + server *core.AppServer, + db *gorm.DB, + config *types.AppConfig, + manager *oss.UploaderManager, + dallService *dalle.Service) *FunctionHandler { + return &FunctionHandler{ + BaseHandler: BaseHandler{ + App: server, + DB: db, + }, + config: config.ApiConfig, + uploadManager: manager, + dallService: dallService, + } +} + +type resVo struct { + Code types.BizCode `json:"code"` + Message string `json:"message"` + Data struct { + Title string `json:"title"` + UpdatedAt string `json:"updated_at"` + Items []dataItem `json:"items"` + } `json:"data"` +} + +type dataItem struct { + Title string `json:"title"` + Url string `json:"url"` + Remark string `json:"remark"` +} + +// check authorization +func (h *FunctionHandler) checkAuth(c *gin.Context) error { + tokenString := c.GetHeader(types.UserAuthHeader) + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return []byte(h.App.Config.Session.SecretKey), nil + }) + + if err != nil { + return fmt.Errorf("error with parse auth token: %v", err) + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return errors.New("token is invalid") + } + + expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0) + if expr > 0 && int64(expr) < time.Now().Unix() { + return errors.New("token is expired") + } + + return nil +} + +// WeiBo 微博热搜 +func (h *FunctionHandler) WeiBo(c *gin.Context) { + if err := h.checkAuth(c); err != nil { + resp.ERROR(c, err.Error()) + return + } + + if h.config.Token == "" { + resp.ERROR(c, "无效的 API Token") + return + } + + url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL) + var res resVo + r, err := req.C().R(). + SetHeader("AppId", h.config.AppId). + SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)). + SetSuccessResult(&res).Get(url) + if err != nil || r.IsErrorState() { + resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err)) + return + } + + if res.Code != types.Success { + resp.ERROR(c, res.Message) + return + } + + builder := make([]string, 0) + builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt)) + for i, v := range res.Data.Items { + builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark)) + } + resp.SUCCESS(c, strings.Join(builder, "\n\n")) +} + +// ZaoBao 今日早报 +func (h *FunctionHandler) ZaoBao(c *gin.Context) { + if err := h.checkAuth(c); err != nil { + resp.ERROR(c, err.Error()) + return + } + + if h.config.Token == "" { + resp.ERROR(c, "无效的 API Token") + return + } + + url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL) + var res resVo + r, err := req.C().R(). + SetHeader("AppId", h.config.AppId). + SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)). + SetSuccessResult(&res).Get(url) + if err != nil || r.IsErrorState() { + resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err)) + return + } + + if res.Code != types.Success { + resp.ERROR(c, res.Message) + return + } + + builder := make([]string, 0) + builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt)) + for _, v := range res.Data.Items { + builder = append(builder, v.Title) + } + builder = append(builder, fmt.Sprintf("%s", res.Data.Title)) + resp.SUCCESS(c, strings.Join(builder, "\n\n")) +} + +// Dall3 DallE3 AI 绘图 +func (h *FunctionHandler) Dall3(c *gin.Context) { + if err := h.checkAuth(c); err != nil { + resp.ERROR(c, err.Error()) + return + } + + var params map[string]interface{} + if err := c.ShouldBindJSON(¶ms); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + logger.Debugf("绘画参数:%+v", params) + var user model.User + res := h.DB.Where("id = ?", params["user_id"]).First(&user) + if res.Error != nil { + resp.ERROR(c, "当前用户不存在!") + return + } + + if user.Power < h.App.SysConfig.DallPower { + resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足") + return + } + + // create dall task + prompt := utils.InterfaceToString(params["prompt"]) + job := model.DallJob{ + UserId: user.Id, + Prompt: prompt, + Power: h.App.SysConfig.DallPower, + } + res = h.DB.Create(&job) + + if res.Error != nil { + resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error()) + return + } + + content, err := h.dallService.Image(types.DallTask{ + JobId: job.Id, + UserId: user.Id, + Prompt: job.Prompt, + N: 1, + Quality: "standard", + Size: "1024x1024", + Style: "vivid", + Power: job.Power, + }, true) + if err != nil { + resp.ERROR(c, "任务执行失败:"+err.Error()) + return + } + + resp.SUCCESS(c, content) +} diff --git a/handler/invite_handler.go b/handler/invite_handler.go new file mode 100644 index 0000000..3e4fdb5 --- /dev/null +++ b/handler/invite_handler.go @@ -0,0 +1,100 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "gorm.io/gorm" + "strings" +) + +// InviteHandler 用户邀请 +type InviteHandler struct { + BaseHandler +} + +func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler { + return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}} +} + +// Code 获取当前用户邀请码 +func (h *InviteHandler) Code(c *gin.Context) { + userId := h.GetLoginUserId(c) + var inviteCode model.InviteCode + res := h.DB.Where("user_id = ?", userId).First(&inviteCode) + // 如果邀请码不存在,则创建一个 + if res.Error != nil { + code := strings.ToUpper(utils.RandString(8)) + for { + res = h.DB.Where("code = ?", code).First(&inviteCode) + if res.Error != nil { // 不存在相同的邀请码则退出 + break + } + } + inviteCode.UserId = userId + inviteCode.Code = code + h.DB.Create(&inviteCode) + } + + var codeVo vo.InviteCode + err := utils.CopyObject(inviteCode, &codeVo) + if err != nil { + resp.ERROR(c, "拷贝对象失败") + return + } + + resp.SUCCESS(c, codeVo) +} + +// List Log 用户邀请记录 +func (h *InviteHandler) List(c *gin.Context) { + + var data struct { + Page int `json:"page"` + PageSize int `json:"page_size"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + userId := h.GetLoginUserId(c) + session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId) + var total int64 + session.Model(&model.InviteLog{}).Count(&total) + var items []model.InviteLog + var list = make([]vo.InviteLog, 0) + offset := (data.Page - 1) * data.PageSize + res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) + if res.Error == nil { + for _, item := range items { + var v vo.InviteLog + err := utils.CopyObject(item, &v) + if err == nil { + v.Id = item.Id + v.CreatedAt = item.CreatedAt.Unix() + list = append(list, v) + } else { + logger.Error(err) + } + } + } + resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) +} + +// Hits 访问邀请码 +func (h *InviteHandler) Hits(c *gin.Context) { + code := c.Query("code") + h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1)) + resp.SUCCESS(c) +} diff --git a/handler/markmap_handler.go b/handler/markmap_handler.go new file mode 100644 index 0000000..bf67ab7 --- /dev/null +++ b/handler/markmap_handler.go @@ -0,0 +1,273 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/store/model" + "geekai/utils" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "gorm.io/gorm" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// MarkMapHandler 生成思维导图 +type MarkMapHandler struct { + BaseHandler + clients *types.LMap[int, *types.WsClient] +} + +func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler { + return &MarkMapHandler{ + BaseHandler: BaseHandler{App: app, DB: db}, + clients: types.NewLMap[int, *types.WsClient](), + } +} + +func (h *MarkMapHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + return + } + + modelId := h.GetInt(c, "model_id", 0) + userId := h.GetInt(c, "user_id", 0) + + client := types.NewWsClient(ws) + h.clients.Put(userId, client) + go func() { + for { + _, msg, err := client.Receive() + if err != nil { + client.Close() + h.clients.Delete(userId) + return + } + + var message types.WsMessage + err = utils.JsonDecode(string(msg), &message) + if err != nil { + continue + } + + // 心跳消息 + if message.Type == "heartbeat" { + logger.Debug("收到 MarkMap 心跳消息:", message.Content) + continue + } + // change model + if message.Type == "model_id" { + modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0) + continue + } + + logger.Info("Receive a message: ", message.Content) + err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId) + if err != nil { + logger.Error(err) + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()}) + } + + } + }() +} + +func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error { + var user model.User + res := h.DB.Model(&model.User{}).First(&user, userId) + if res.Error != nil { + return fmt.Errorf("error with query user info: %v", res.Error) + } + var chatModel model.ChatModel + res = h.DB.Where("id", modelId).First(&chatModel) + if res.Error != nil { + return fmt.Errorf("error with query chat model: %v", res.Error) + } + + if user.Status == false { + return errors.New("当前用户被禁用") + } + + if user.Power < chatModel.Power { + return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power) + } + + messages := make([]interface{}, 0) + messages = append(messages, types.Message{Role: "system", Content: ` +你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子: +# Geek-AI 助手 + +## 完整的开源系统 +### 前端开源 +### 后端开源 + +## 支持各种大模型 +### OpenAI +### Azure +### 文心一言 +### 通义千问 + +## 集成多种收费方式 +### 支付宝 +### 微信 + +另外,除此之外不要任何解释性语句。 +`}) + messages = append(messages, types.Message{Role: "user", Content: prompt}) + var req = types.ApiRequest{ + Model: chatModel.Value, + Stream: true, + Messages: messages, + } + + var apiKey model.ApiKey + response, err := h.doRequest(req, chatModel, &apiKey) + if err != nil { + return fmt.Errorf("请求 OpenAI API 失败: %s", err) + } + + defer response.Body.Close() + + contentType := response.Header.Get("Content-Type") + if strings.Contains(contentType, "text/event-stream") { + // 循环读取 Chunk 消息 + scanner := bufio.NewScanner(response.Body) + var isNew = true + for scanner.Scan() { + line := scanner.Text() + if !strings.Contains(line, "data:") || len(line) < 30 { + continue + } + + var responseBody = types.ApiResponse{} + err = json.Unmarshal([]byte(line[6:]), &responseBody) + if err != nil { // 数据解析出错 + return fmt.Errorf("error with decode data: %v", line) + } + + if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行 + continue + } + + if responseBody.Choices[0].FinishReason == "stop" { + break + } + + if isNew { + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart}) + isNew = false + } + utils.ReplyChunkMessage(client, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), + }) + } // end for + + utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) + + } else { + body, err := io.ReadAll(response.Body) + if err != nil { + return fmt.Errorf("读取响应失败: %v", err) + } + var res types.ApiError + err = json.Unmarshal(body, &res) + if err != nil { + return fmt.Errorf("解析响应失败: %v", err) + } + + // OpenAI API 调用异常处理 + if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") { + // remove key + h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{}) + return errors.New("请求 OpenAI API 失败:API KEY 所关联的账户被禁用。") + } else if strings.Contains(res.Error.Message, "You exceeded your current quota") { + return errors.New("请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。") + } else { + return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message) + } + } + + // 扣减算力 + res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power)) + if res.Error == nil { + // 记录算力消费日志 + var u model.User + h.DB.Where("id", userId).First(&u) + h.DB.Create(&model.PowerLog{ + UserId: u.Id, + Username: u.Username, + Type: types.PowerConsume, + Amount: chatModel.Power, + Mark: types.PowerSub, + Balance: u.Power, + Model: chatModel.Value, + Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value), + CreatedAt: time.Now(), + }) + } + + return nil +} + +func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) { + // if the chat model bind a KEY, use it directly + var res *gorm.DB + if chatModel.KeyId > 0 { + res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey) + } + // use the last unused key + if apiKey.Id == 0 { + res = h.DB.Where("platform", types.OpenAI). + Where("type", "chat"). + Where("enabled", true).Order("last_used_at ASC").First(apiKey) + } + if res.Error != nil { + return nil, errors.New("no available key, please import key") + } + apiURL := apiKey.ApiURL + // 更新 API KEY 的最后使用时间 + h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) + + // 创建 HttpClient 请求对象 + var client *http.Client + requestBody, err := json.Marshal(req) + if err != nil { + return nil, err + } + request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody)) + if err != nil { + return nil, err + } + + request.Header.Set("Content-Type", "application/json") + if len(apiKey.ProxyURL) > 5 { // 使用代理 + proxy, _ := url.Parse(apiKey.ProxyURL) + client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxy), + }, + } + } else { + client = http.DefaultClient + } + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value)) + return client.Do(request) +} diff --git a/handler/menu_handler.go b/handler/menu_handler.go new file mode 100644 index 0000000..647ed1e --- /dev/null +++ b/handler/menu_handler.go @@ -0,0 +1,43 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type MenuHandler struct { + BaseHandler +} + +func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler { + return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}} +} + +// List 数据列表 +func (h *MenuHandler) List(c *gin.Context) { + var items []model.Menu + var list = make([]vo.Menu, 0) + res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items) + if res.Error == nil { + for _, item := range items { + var product vo.Menu + err := utils.CopyObject(item, &product) + if err == nil { + list = append(list, product) + } + } + } + resp.SUCCESS(c, list) +} diff --git a/handler/mj_handler.go b/handler/mj_handler.go new file mode 100644 index 0000000..822df42 --- /dev/null +++ b/handler/mj_handler.go @@ -0,0 +1,520 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "encoding/base64" + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/service" + "geekai/service/mj" + "geekai/service/oss" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "gorm.io/gorm" +) + +type MidJourneyHandler struct { + BaseHandler + pool *mj.ServicePool + snowflake *service.Snowflake + uploader *oss.UploaderManager +} + +func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler { + return &MidJourneyHandler{ + snowflake: snowflake, + pool: pool, + uploader: manager, + BaseHandler: BaseHandler{ + App: app, + DB: db, + }, + } +} + +func (h *MidJourneyHandler) preCheck(c *gin.Context) bool { + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return false + } + + if user.Power < h.App.SysConfig.MjPower { + resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") + return false + } + + if !h.pool.HasAvailableService() { + resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!") + return false + } + + return true + +} + +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *MidJourneyHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + c.Abort() + return + } + + userId := h.GetInt(c, "user_id", 0) + if userId == 0 { + logger.Info("Invalid user ID") + c.Abort() + return + } + + client := types.NewWsClient(ws) + h.pool.Clients.Put(uint(userId), client) + logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) +} + +// Image 创建一个绘画任务 +func (h *MidJourneyHandler) Image(c *gin.Context) { + var data struct { + SessionId string `json:"session_id"` + TaskType string `json:"task_type"` + Prompt string `json:"prompt"` + NegPrompt string `json:"neg_prompt"` + Rate string `json:"rate"` + Model string `json:"model"` + Chaos int `json:"chaos"` + Raw bool `json:"raw"` + Seed int64 `json:"seed"` + Stylize int `json:"stylize"` + ImgArr []string `json:"img_arr"` + Tile bool `json:"tile"` + Quality float32 `json:"quality"` + Iw float32 `json:"iw"` + CRef string `json:"cref"` //生成角色一致的图像 + SRef string `json:"sref"` //生成风格一致的图像 + Cw int `json:"cw"` // 参考程度 + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + if !h.preCheck(c) { + return + } + + var params = "" + if data.Rate != "" && !strings.Contains(params, "--ar") { + params += " --ar " + data.Rate + } + if data.Seed > 0 && !strings.Contains(params, "--seed") { + params += fmt.Sprintf(" --seed %d", data.Seed) + } + if data.Stylize > 0 && !strings.Contains(params, "--s") && !strings.Contains(params, "--stylize") { + params += fmt.Sprintf(" --s %d", data.Stylize) + } + if data.Chaos > 0 && !strings.Contains(params, "--c") && !strings.Contains(params, "--chaos") { + params += fmt.Sprintf(" --c %d", data.Chaos) + } + if len(data.ImgArr) > 0 && data.Iw > 0 { + params += fmt.Sprintf(" --iw %.2f", data.Iw) + } + if data.Raw { + params += " --style raw" + } + if data.Quality > 0 { + params += fmt.Sprintf(" --q %.2f", data.Quality) + } + if data.Tile { + params += " --tile " + } + if data.CRef != "" { + params += fmt.Sprintf(" --cref %s", data.CRef) + if data.Cw > 0 { + params += fmt.Sprintf(" --cw %d", data.Cw) + } else { + params += " --cw 100" + } + } + + if data.SRef != "" { + params += fmt.Sprintf(" --sref %s", data.SRef) + } + if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") { + params += fmt.Sprintf(" %s", data.Model) + } + + // 处理融图和换脸的提示词 + if data.TaskType == types.TaskSwapFace.String() || data.TaskType == types.TaskBlend.String() { + params = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ",")) + } + + // 如果本地图片上传的是相对地址,处理成绝对地址 + for k, v := range data.ImgArr { + if !strings.HasPrefix(v, "http") { + data.ImgArr[k] = fmt.Sprintf("http://localhost:5678/%s", strings.TrimLeft(v, "/")) + } + } + + idValue, _ := c.Get(types.LoginUserID) + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + // generate task id + taskId, err := h.snowflake.Next(true) + if err != nil { + resp.ERROR(c, "error with generate task id: "+err.Error()) + return + } + job := model.MidJourneyJob{ + Type: data.TaskType, + UserId: userId, + TaskId: taskId, + Progress: 0, + Prompt: fmt.Sprintf("%s %s", data.Prompt, params), + Power: h.App.SysConfig.MjPower, + CreatedAt: time.Now(), + } + opt := "绘图" + if data.TaskType == types.TaskBlend.String() { + job.Prompt = "融图:" + strings.Join(data.ImgArr, ",") + opt = "融图" + } else if data.TaskType == types.TaskSwapFace.String() { + job.Prompt = "换脸:" + strings.Join(data.ImgArr, ",") + opt = "换脸" + } + + if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 { + resp.ERROR(c, "添加任务失败:"+res.Error.Error()) + return + } + + h.pool.PushTask(types.MjTask{ + Id: job.Id, + TaskId: taskId, + SessionId: data.SessionId, + Type: types.TaskType(data.TaskType), + Prompt: data.Prompt, + NegPrompt: data.NegPrompt, + Params: params, + UserId: userId, + ImgArr: data.ImgArr, + }) + + client := h.pool.Clients.Get(uint(job.UserId)) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + + // update user's power + tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) + // 记录算力变化日志 + if tx.Error == nil && tx.RowsAffected > 0 { + user, _ := h.GetLoginUser(c) + h.DB.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: job.Power, + Balance: user.Power - job.Power, + Mark: types.PowerSub, + Model: "mid-journey", + Remark: fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId), + CreatedAt: time.Now(), + }) + } + resp.SUCCESS(c) +} + +type reqVo struct { + Index int `json:"index"` + ChannelId string `json:"channel_id"` + MessageId string `json:"message_id"` + MessageHash string `json:"message_hash"` + SessionId string `json:"session_id"` + Prompt string `json:"prompt"` + ChatId string `json:"chat_id"` + RoleId int `json:"role_id"` + Icon string `json:"icon"` +} + +// Upscale send upscale command to MidJourney Bot +func (h *MidJourneyHandler) Upscale(c *gin.Context) { + var data reqVo + if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + + if !h.preCheck(c) { + return + } + + idValue, _ := c.Get(types.LoginUserID) + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + taskId, _ := h.snowflake.Next(true) + job := model.MidJourneyJob{ + Type: types.TaskUpscale.String(), + ReferenceId: data.MessageId, + UserId: userId, + TaskId: taskId, + Progress: 0, + Prompt: data.Prompt, + Power: h.App.SysConfig.MjActionPower, + CreatedAt: time.Now(), + } + if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 { + resp.ERROR(c, "添加任务失败:"+res.Error.Error()) + return + } + + h.pool.PushTask(types.MjTask{ + Id: job.Id, + SessionId: data.SessionId, + Type: types.TaskUpscale, + Prompt: data.Prompt, + UserId: userId, + ChannelId: data.ChannelId, + Index: data.Index, + MessageId: data.MessageId, + MessageHash: data.MessageHash, + }) + + client := h.pool.Clients.Get(uint(job.UserId)) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + // update user's power + tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) + // 记录算力变化日志 + if tx.Error == nil && tx.RowsAffected > 0 { + user, _ := h.GetLoginUser(c) + h.DB.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: job.Power, + Balance: user.Power - job.Power, + Mark: types.PowerSub, + Model: "mid-journey", + Remark: fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId), + CreatedAt: time.Now(), + }) + } + resp.SUCCESS(c) +} + +// Variation send variation command to MidJourney Bot +func (h *MidJourneyHandler) Variation(c *gin.Context) { + var data reqVo + if err := c.ShouldBindJSON(&data); err != nil || data.SessionId == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + + if !h.preCheck(c) { + return + } + + idValue, _ := c.Get(types.LoginUserID) + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + taskId, _ := h.snowflake.Next(true) + job := model.MidJourneyJob{ + Type: types.TaskVariation.String(), + ChannelId: data.ChannelId, + ReferenceId: data.MessageId, + UserId: userId, + TaskId: taskId, + Progress: 0, + Prompt: data.Prompt, + Power: h.App.SysConfig.MjActionPower, + CreatedAt: time.Now(), + } + if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 { + resp.ERROR(c, "添加任务失败:"+res.Error.Error()) + return + } + + h.pool.PushTask(types.MjTask{ + Id: job.Id, + SessionId: data.SessionId, + Type: types.TaskVariation, + Prompt: data.Prompt, + UserId: userId, + Index: data.Index, + ChannelId: data.ChannelId, + MessageId: data.MessageId, + MessageHash: data.MessageHash, + }) + + client := h.pool.Clients.Get(uint(job.UserId)) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + + // update user's power + tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) + // 记录算力变化日志 + if tx.Error == nil && tx.RowsAffected > 0 { + user, _ := h.GetLoginUser(c) + h.DB.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: job.Power, + Balance: user.Power - job.Power, + Mark: types.PowerSub, + Model: "mid-journey", + Remark: fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId), + CreatedAt: time.Now(), + }) + } + resp.SUCCESS(c) +} + +// ImgWall 照片墙 +func (h *MidJourneyHandler) ImgWall(c *gin.Context) { + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + err, jobs := h.getData(true, 0, page, pageSize, true) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 MJ 任务列表 +func (h *MidJourneyHandler) JobList(c *gin.Context) { + status := h.GetBool(c, "status") + userId := h.GetLoginUserId(c) + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + publish := h.GetBool(c, "publish") + + err, jobs := h.getData(status, userId, page, pageSize, publish) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 MJ 任务列表 +func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) { + session := h.DB.Session(&gorm.Session{}) + if finish { + session = session.Where("progress = ?", 100).Order("id DESC") + } else { + session = session.Where("progress < ?", 100).Order("id ASC") + } + if userId > 0 { + session = session.Where("user_id = ?", userId) + } + if publish { + session = session.Where("publish = ?", publish) + } + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + session = session.Offset(offset).Limit(pageSize) + } + + var items []model.MidJourneyJob + res := session.Find(&items) + if res.Error != nil { + return res.Error, nil + } + + var jobs = make([]vo.MidJourneyJob, 0) + for _, item := range items { + var job vo.MidJourneyJob + err := utils.CopyObject(item, &job) + if err != nil { + continue + } + + if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" { + // discord 服务器图片需要使用代理转发图片数据流 + if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") { + image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL) + if err == nil { + job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image) + } + } else { + job.ImgURL = job.OrgURL + } + } + + jobs = append(jobs, job) + } + return nil, jobs +} + +// Remove remove task image +func (h *MidJourneyHandler) Remove(c *gin.Context) { + var data struct { + Id uint `json:"id"` + UserId uint `json:"user_id"` + ImgURL string `json:"img_url"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + // remove job recode + res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id}) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + // remove image + err := h.uploader.GetUploadHandler().Delete(data.ImgURL) + if err != nil { + logger.Error("remove image failed: ", err) + } + + client := h.pool.Clients.Get(data.UserId) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + + resp.SUCCESS(c) +} + +// Publish 发布图片到画廊显示 +func (h *MidJourneyHandler) Publish(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享 + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败") + return + } + + resp.SUCCESS(c) +} diff --git a/handler/order_handler.go b/handler/order_handler.go new file mode 100644 index 0000000..a56daa8 --- /dev/null +++ b/handler/order_handler.go @@ -0,0 +1,62 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type OrderHandler struct { + BaseHandler +} + +func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler { + return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}} +} + +func (h *OrderHandler) List(c *gin.Context) { + var data struct { + Page int `json:"page"` + PageSize int `json:"page_size"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + userId := h.GetLoginUserId(c) + session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess) + var total int64 + session.Model(&model.Order{}).Count(&total) + var items []model.Order + var list = make([]vo.Order, 0) + offset := (data.Page - 1) * data.PageSize + res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) + if res.Error == nil { + for _, item := range items { + var order vo.Order + err := utils.CopyObject(item, &order) + if err == nil { + order.Id = item.Id + order.CreatedAt = item.CreatedAt.Unix() + order.UpdatedAt = item.UpdatedAt.Unix() + list = append(list, order) + } else { + logger.Error(err) + } + } + } + resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) +} diff --git a/handler/payment_handler.go b/handler/payment_handler.go new file mode 100644 index 0000000..9c18651 --- /dev/null +++ b/handler/payment_handler.go @@ -0,0 +1,613 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "embed" + "encoding/base64" + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/service" + "geekai/service/payment" + "geekai/store/model" + "geekai/utils" + "geekai/utils/resp" + "github.com/shopspring/decimal" + "math" + "net/http" + "net/url" + "sync" + "time" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +const ( + PayWayAlipay = "支付宝" + PayWayXunHu = "虎皮椒" + PayWayJs = "PayJS" +) + +// PaymentHandler 支付服务回调 handler +type PaymentHandler struct { + BaseHandler + alipayService *payment.AlipayService + huPiPayService *payment.HuPiPayService + js *payment.PayJS + snowflake *service.Snowflake + fs embed.FS + lock sync.Mutex + signKey string // 用来签名的随机秘钥 +} + +func NewPaymentHandler( + server *core.AppServer, + alipayService *payment.AlipayService, + huPiPayService *payment.HuPiPayService, + js *payment.PayJS, + db *gorm.DB, + snowflake *service.Snowflake, + fs embed.FS) *PaymentHandler { + return &PaymentHandler{ + alipayService: alipayService, + huPiPayService: huPiPayService, + js: js, + snowflake: snowflake, + fs: fs, + lock: sync.Mutex{}, + BaseHandler: BaseHandler{ + App: server, + DB: db, + }, + signKey: utils.RandString(32), + } +} + +func (h *PaymentHandler) DoPay(c *gin.Context) { + orderNo := h.GetTrim(c, "order_no") + payWay := h.GetTrim(c, "pay_way") + t := h.GetInt(c, "t", 0) + sign := h.GetTrim(c, "sign") + signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, payWay, t, h.signKey) + newSign := utils.Sha256(signStr) + if newSign != sign { + resp.ERROR(c, "订单签名错误!") + return + } + + // 检查二维码是否过期 + if time.Now().Unix()-int64(t) > int64(h.App.SysConfig.OrderPayTimeout) { + resp.ERROR(c, "支付二维码已过期,请重新生成!") + return + } + + if orderNo == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + + var order model.Order + res := h.DB.Where("order_no = ?", orderNo).First(&order) + if res.Error != nil { + resp.ERROR(c, "Order not found") + return + } + + // fix: 这里先检查一下订单状态,如果已经支付了,就直接返回 + if order.Status == types.OrderPaidSuccess { + resp.ERROR(c, "This order had been paid, please do not pay twice") + return + } + + // 更新扫码状态 + h.DB.Model(&order).UpdateColumn("status", types.OrderScanned) + if payWay == "alipay" { // 支付宝 + // 生成支付链接 + notifyURL := h.App.Config.AlipayConfig.NotifyURL + returnURL := "" // 关闭同步回跳 + amount := fmt.Sprintf("%.2f", order.Amount) + + uri, err := h.alipayService.PayUrlMobile(order.OrderNo, notifyURL, returnURL, amount, order.Subject) + if err != nil { + resp.ERROR(c, "error with generate pay url: "+err.Error()) + return + } + + c.Redirect(302, uri) + return + } else if payWay == "hupi" { // 虎皮椒支付 + params := payment.HuPiPayReq{ + Version: "1.1", + TradeOrderId: orderNo, + TotalFee: fmt.Sprintf("%f", order.Amount), + Title: order.Subject, + NotifyURL: h.App.Config.HuPiPayConfig.NotifyURL, + WapName: "极客学长", + } + r, err := h.huPiPayService.Pay(params) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + c.Redirect(302, r.URL) + } + resp.ERROR(c, "Invalid operations") +} + +// OrderQuery 查询订单状态 +func (h *PaymentHandler) OrderQuery(c *gin.Context) { + var data struct { + OrderNo string `json:"order_no"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + var order model.Order + res := h.DB.Where("order_no = ?", data.OrderNo).First(&order) + if res.Error != nil { + resp.ERROR(c, "Order not found") + return + } + + if order.Status == types.OrderPaidSuccess { + resp.SUCCESS(c, gin.H{"status": order.Status}) + return + } + + counter := 0 + for { + time.Sleep(time.Second) + var item model.Order + h.DB.Where("order_no = ?", data.OrderNo).First(&item) + if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status { + order.Status = item.Status + break + } + counter++ + } + + resp.SUCCESS(c, gin.H{"status": order.Status}) +} + +// PayQrcode 生成支付 URL 二维码 +func (h *PaymentHandler) PayQrcode(c *gin.Context) { + var data struct { + PayWay string `json:"pay_way"` // 支付方式 + ProductId uint `json:"product_id"` + UserId int `json:"user_id"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + var product model.Product + res := h.DB.First(&product, data.ProductId) + if res.Error != nil { + resp.ERROR(c, "Product not found") + return + } + + orderNo, err := h.snowflake.Next(false) + if err != nil { + resp.ERROR(c, "error with generate trade no: "+err.Error()) + return + } + var user model.User + res = h.DB.First(&user, data.UserId) + if res.Error != nil { + resp.ERROR(c, "Invalid user ID") + return + } + + var payWay string + var notifyURL string + switch data.PayWay { + case "hupi": + payWay = PayWayXunHu + notifyURL = h.App.Config.HuPiPayConfig.NotifyURL + case "payjs": + payWay = PayWayJs + notifyURL = h.App.Config.JPayConfig.NotifyURL + default: + payWay = PayWayAlipay + notifyURL = h.App.Config.AlipayConfig.NotifyURL + } + // 创建订单 + remark := types.OrderRemark{ + Days: product.Days, + Power: product.Power, + Name: product.Name, + Price: product.Price, + Discount: product.Discount, + } + + amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64() + order := model.Order{ + UserId: user.Id, + Username: user.Username, + ProductId: product.Id, + OrderNo: orderNo, + Subject: product.Name, + Amount: amount, + Status: types.OrderNotPaid, + PayWay: payWay, + Remark: utils.JsonEncode(remark), + } + res = h.DB.Create(&order) + if res.Error != nil || res.RowsAffected == 0 { + resp.ERROR(c, "error with create order: "+res.Error.Error()) + return + } + + // PayJs 单独处理,只能用官方生成的二维码 + if data.PayWay == "payjs" { + params := payment.JPayReq{ + TotalFee: int(math.Ceil(order.Amount * 100)), + OutTradeNo: order.OrderNo, + Subject: product.Name, + } + r := h.js.Pay(params) + if r.IsOK() { + resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode}) + return + } else { + resp.ERROR(c, "error with generating payment qrcode: "+r.ReturnMsg) + return + } + } + + var logo string + if data.PayWay == "alipay" { + logo = "res/img/alipay.jpg" + } else if data.PayWay == "hupi" { + if h.App.Config.HuPiPayConfig.Name == "wechat" { + logo = "res/img/wechat-pay.jpg" + } else { + logo = "res/img/alipay.jpg" + } + } + + file, err := h.fs.Open(logo) + if err != nil { + resp.ERROR(c, "error with open qrcode log file: "+err.Error()) + return + } + + parse, err := url.Parse(notifyURL) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + timestamp := time.Now().Unix() + signStr := fmt.Sprintf("%s-%s-%d-%s", orderNo, data.PayWay, timestamp, h.signKey) + sign := utils.Sha256(signStr) + imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s&t=%d&sign=%s", parse.Scheme, parse.Host, orderNo, data.PayWay, timestamp, sign) + imgData, err := utils.GenQrcode(imageURL, 400, file) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + imgDataBase64 := base64.StdEncoding.EncodeToString(imgData) + resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL}) +} + +// Mobile 移动端支付 +func (h *PaymentHandler) Mobile(c *gin.Context) { + var data struct { + PayWay string `json:"pay_way"` // 支付方式 + ProductId uint `json:"product_id"` + UserId int `json:"user_id"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + var product model.Product + res := h.DB.First(&product, data.ProductId) + if res.Error != nil { + resp.ERROR(c, "Product not found") + return + } + + orderNo, err := h.snowflake.Next(false) + if err != nil { + resp.ERROR(c, "error with generate trade no: "+err.Error()) + return + } + var user model.User + res = h.DB.First(&user, data.UserId) + if res.Error != nil { + resp.ERROR(c, "Invalid user ID") + return + } + + amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64() + var payWay string + var notifyURL, returnURL string + var payURL string + switch data.PayWay { + case "hupi": + payWay = PayWayXunHu + notifyURL = h.App.Config.HuPiPayConfig.NotifyURL + returnURL = h.App.Config.HuPiPayConfig.ReturnURL + parse, _ := url.Parse(h.App.Config.HuPiPayConfig.ReturnURL) + baseURL := fmt.Sprintf("%s://%s", parse.Scheme, parse.Host) + params := payment.HuPiPayReq{ + Version: "1.1", + TradeOrderId: orderNo, + TotalFee: fmt.Sprintf("%f", amount), + Title: product.Name, + NotifyURL: notifyURL, + ReturnURL: returnURL, + CallbackURL: returnURL, + WapName: "极客学长", + WapUrl: baseURL, + Type: "WAP", + } + r, err := h.huPiPayService.Pay(params) + if err != nil { + logger.Error("error with generating Pay URL: ", err.Error()) + resp.ERROR(c, "error with generating Pay URL: "+err.Error()) + return + } + payURL = r.URL + case "payjs": + payWay = PayWayJs + notifyURL = h.App.Config.JPayConfig.NotifyURL + returnURL = h.App.Config.JPayConfig.ReturnURL + totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart() + params := url.Values{} + params.Add("total_fee", fmt.Sprintf("%d", totalFee)) + params.Add("out_trade_no", orderNo) + params.Add("body", product.Name) + params.Add("notify_url", notifyURL) + params.Add("auto", "0") + payURL = h.js.PayH5(params) + case "alipay": + payWay = PayWayAlipay + notifyURL = h.App.Config.AlipayConfig.NotifyURL + returnURL = h.App.Config.AlipayConfig.ReturnURL + payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name) + if err != nil { + resp.ERROR(c, "error with generating Pay URL: "+err.Error()) + return + } + default: + resp.ERROR(c, "Unsupported pay way: "+data.PayWay) + return + } + // 创建订单 + remark := types.OrderRemark{ + Days: product.Days, + Power: product.Power, + Name: product.Name, + Price: product.Price, + Discount: product.Discount, + } + + order := model.Order{ + UserId: user.Id, + Username: user.Username, + ProductId: product.Id, + OrderNo: orderNo, + Subject: product.Name, + Amount: amount, + Status: types.OrderNotPaid, + PayWay: payWay, + Remark: utils.JsonEncode(remark), + } + res = h.DB.Create(&order) + if res.Error != nil || res.RowsAffected == 0 { + resp.ERROR(c, "error with create order: "+res.Error.Error()) + return + } + + resp.SUCCESS(c, payURL) +} + +// 异步通知回调公共逻辑 +func (h *PaymentHandler) notify(orderNo string, tradeNo string) error { + var order model.Order + res := h.DB.Where("order_no = ?", orderNo).First(&order) + if res.Error != nil { + err := fmt.Errorf("error with fetch order: %v", res.Error) + logger.Error(err) + return err + } + + h.lock.Lock() + defer h.lock.Unlock() + + // 已支付订单,直接返回 + if order.Status == types.OrderPaidSuccess { + return nil + } + + var user model.User + res = h.DB.First(&user, order.UserId) + if res.Error != nil { + err := fmt.Errorf("error with fetch user info: %v", res.Error) + logger.Error(err) + return err + } + + var remark types.OrderRemark + err := utils.JsonDecode(order.Remark, &remark) + if err != nil { + err := fmt.Errorf("error with decode order remark: %v", err) + logger.Error(err) + return err + } + + var opt string + var power int + if remark.Days > 0 { // VIP 充值 + if user.ExpiredTime >= time.Now().Unix() { + user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix() + opt = "VIP充值,VIP 没到期,只延期不增加算力" + } else { + user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix() + user.Power += h.App.SysConfig.VipMonthPower + power = h.App.SysConfig.VipMonthPower + opt = "VIP充值" + } + user.Vip = true + } else { // 充值点卡,直接增加次数即可 + user.Power += remark.Power + opt = "点卡充值" + power = remark.Power + } + + // 更新用户信息 + res = h.DB.Updates(&user) + if res.Error != nil { + err := fmt.Errorf("error with update user info: %v", res.Error) + logger.Error(err) + return err + } + + // 更新订单状态 + order.PayTime = time.Now().Unix() + order.Status = types.OrderPaidSuccess + order.TradeNo = tradeNo + res = h.DB.Updates(&order) + if res.Error != nil { + err := fmt.Errorf("error with update order info: %v", res.Error) + logger.Error(err) + return err + } + + // 更新产品销量 + h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1)) + + // 记录算力充值日志 + if opt != "" { + h.DB.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerRecharge, + Amount: power, + Balance: user.Power, + Mark: types.PowerAdd, + Model: order.PayWay, + Remark: fmt.Sprintf("%s,金额:%f,订单号:%s", opt, order.Amount, order.OrderNo), + CreatedAt: time.Now(), + }) + } + + return nil +} + +// GetPayWays 获取支付方式 +func (h *PaymentHandler) GetPayWays(c *gin.Context) { + data := gin.H{} + if h.App.Config.AlipayConfig.Enabled { + data["alipay"] = gin.H{"name": "alipay"} + } + if h.App.Config.HuPiPayConfig.Enabled { + data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name} + } + if h.App.Config.JPayConfig.Enabled { + data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name} + } + resp.SUCCESS(c, data) +} + +// HuPiPayNotify 虎皮椒支付异步回调 +func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) { + err := c.Request.ParseForm() + if err != nil { + c.String(http.StatusOK, "fail") + return + } + + orderNo := c.Request.Form.Get("trade_order_id") + tradeNo := c.Request.Form.Get("open_order_id") + logger.Infof("收到虎皮椒订单支付回调,订单 NO:%s,交易流水号:%s", orderNo, tradeNo) + + if err = h.huPiPayService.Check(tradeNo); err != nil { + logger.Error("订单校验失败:", err) + c.String(http.StatusOK, "fail") + return + } + err = h.notify(orderNo, tradeNo) + if err != nil { + c.String(http.StatusOK, "fail") + return + } + + c.String(http.StatusOK, "success") +} + +// AlipayNotify 支付宝支付回调 +func (h *PaymentHandler) AlipayNotify(c *gin.Context) { + err := c.Request.ParseForm() + if err != nil { + c.String(http.StatusOK, "fail") + return + } + + // TODO:验证交易签名 + res := h.alipayService.TradeVerify(c.Request.Form) + logger.Infof("验证支付结果:%+v", res) + if !res.Success() { + logger.Error("订单校验失败:", res.Message) + c.String(http.StatusOK, "fail") + return + } + + tradeNo := c.Request.Form.Get("trade_no") + err = h.notify(res.OutTradeNo, tradeNo) + if err != nil { + c.String(http.StatusOK, "fail") + return + } + + c.String(http.StatusOK, "success") +} + +// PayJsNotify PayJs 支付异步回调 +func (h *PaymentHandler) PayJsNotify(c *gin.Context) { + err := c.Request.ParseForm() + if err != nil { + c.String(http.StatusOK, "fail") + return + } + + orderNo := c.Request.Form.Get("out_trade_no") + returnCode := c.Request.Form.Get("return_code") + logger.Infof("收到订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode) + // 支付失败 + if returnCode != "1" { + return + } + + // 校验订单支付状态 + tradeNo := c.Request.Form.Get("payjs_order_id") + err = h.js.Check(tradeNo) + if err != nil { + logger.Error("订单校验失败:", err) + c.String(http.StatusOK, "fail") + return + } + + err = h.notify(orderNo, tradeNo) + if err != nil { + c.String(http.StatusOK, "fail") + return + } + + c.String(http.StatusOK, "success") +} diff --git a/handler/power_log_handler.go b/handler/power_log_handler.go new file mode 100644 index 0000000..7773221 --- /dev/null +++ b/handler/power_log_handler.go @@ -0,0 +1,74 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type PowerLogHandler struct { + BaseHandler +} + +func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler { + return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}} +} + +func (h *PowerLogHandler) List(c *gin.Context) { + var data struct { + Model string `json:"model"` + Date []string `json:"date"` + Page int `json:"page"` + PageSize int `json:"page_size"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + session := h.DB.Session(&gorm.Session{}) + userId := h.GetLoginUserId(c) + session = session.Where("user_id", userId) + if data.Model != "" { + session = session.Where("model", data.Model) + } + if len(data.Date) == 2 { + start := data.Date[0] + " 00:00:00" + end := data.Date[1] + " 00:00:00" + session = session.Where("created_at >= ? AND created_at <= ?", start, end) + } + + var total int64 + session.Model(&model.PowerLog{}).Count(&total) + var items []model.PowerLog + var list = make([]vo.PowerLog, 0) + offset := (data.Page - 1) * data.PageSize + res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items) + if res.Error == nil { + for _, item := range items { + var log vo.PowerLog + err := utils.CopyObject(item, &log) + if err != nil { + continue + } + log.Id = item.Id + log.CreatedAt = item.CreatedAt.Unix() + log.TypeStr = item.Type.String() + list = append(list, log) + } + } + resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list)) +} diff --git a/handler/product_handler.go b/handler/product_handler.go new file mode 100644 index 0000000..34959aa --- /dev/null +++ b/handler/product_handler.go @@ -0,0 +1,48 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type ProductHandler struct { + BaseHandler +} + +func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler { + return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}} +} + +// List 模型列表 +func (h *ProductHandler) List(c *gin.Context) { + var items []model.Product + var list = make([]vo.Product, 0) + res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items) + if res.Error == nil { + for _, item := range items { + var product vo.Product + err := utils.CopyObject(item, &product) + if err == nil { + product.Id = item.Id + product.CreatedAt = item.CreatedAt.Unix() + product.UpdatedAt = item.UpdatedAt.Unix() + list = append(list, product) + } else { + logger.Error(err) + } + } + } + resp.SUCCESS(c, list) +} diff --git a/handler/reward_handler.go b/handler/reward_handler.go new file mode 100644 index 0000000..44045f5 --- /dev/null +++ b/handler/reward_handler.go @@ -0,0 +1,108 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "gorm.io/gorm" + "math" + "strings" + "sync" + "time" +) + +type RewardHandler struct { + BaseHandler + lock sync.Mutex +} + +func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler { + return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}} +} + +// Verify 打赏码核销 +func (h *RewardHandler) Verify(c *gin.Context) { + var data struct { + TxId string `json:"tx_id"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + user, err := h.GetLoginUser(c) + if err != nil { + resp.HACKER(c) + return + } + + // 移除转账单号中间的空格,防止有人复制的时候多复制了空格 + data.TxId = strings.ReplaceAll(data.TxId, " ", "") + + h.lock.Lock() + defer h.lock.Unlock() + + var item model.Reward + res := h.DB.Where("tx_id = ?", data.TxId).First(&item) + if res.Error != nil { + resp.ERROR(c, "无效的众筹交易流水号!") + return + } + + if item.Status { + resp.ERROR(c, "当前众筹交易流水号已经被核销,请不要重复核销!") + return + } + + tx := h.DB.Begin() + exchange := vo.RewardExchange{} + power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice) + exchange.Power = int(power) + res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power)) + if res.Error != nil { + tx.Rollback() + logger.Error("添加应用失败:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + + // 更新核销状态 + item.Status = true + item.UserId = user.Id + item.Exchange = utils.JsonEncode(exchange) + res = tx.Updates(&item) + if res.Error != nil { + tx.Rollback() + logger.Error("添加应用失败:", res.Error) + resp.ERROR(c, "更新数据库失败!") + return + } + + // 记录算力充值日志 + h.DB.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerReward, + Amount: exchange.Power, + Balance: user.Power + exchange.Power, + Mark: types.PowerAdd, + Model: "众筹支付", + Remark: fmt.Sprintf("众筹充值算力,金额:%f,价格:%f", item.Amount, h.App.SysConfig.PowerPrice), + CreatedAt: time.Now(), + }) + tx.Commit() + resp.SUCCESS(c) + +} diff --git a/handler/sd_handler.go b/handler/sd_handler.go new file mode 100644 index 0000000..e30e837 --- /dev/null +++ b/handler/sd_handler.go @@ -0,0 +1,334 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/service" + "geekai/service/oss" + "geekai/service/sd" + "geekai/store" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "net/http" + "time" + + "github.com/gorilla/websocket" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" + "gorm.io/gorm" +) + +type SdJobHandler struct { + BaseHandler + redis *redis.Client + pool *sd.ServicePool + uploader *oss.UploaderManager + snowflake *service.Snowflake + leveldb *store.LevelDB +} + +func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler { + return &SdJobHandler{ + pool: pool, + uploader: manager, + snowflake: snowflake, + leveldb: levelDB, + BaseHandler: BaseHandler{ + App: app, + DB: db, + }, + } +} + +// Client WebSocket 客户端,用于通知任务状态变更 +func (h *SdJobHandler) Client(c *gin.Context) { + ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + if err != nil { + logger.Error(err) + c.Abort() + return + } + + userId := h.GetInt(c, "user_id", 0) + if userId == 0 { + logger.Info("Invalid user ID") + c.Abort() + return + } + + client := types.NewWsClient(ws) + h.pool.Clients.Put(uint(userId), client) + logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) +} + +func (h *SdJobHandler) preCheck(c *gin.Context) bool { + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return false + } + + if !h.pool.HasAvailableService() { + resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!") + return false + } + + if user.Power < h.App.SysConfig.SdPower { + resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!") + return false + } + + return true + +} + +// Image 创建一个绘画任务 +func (h *SdJobHandler) Image(c *gin.Context) { + if !h.preCheck(c) { + return + } + + var data struct { + SessionId string `json:"session_id"` + types.SdTaskParams + } + if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" { + resp.ERROR(c, types.InvalidArgs) + return + } + + if data.Width <= 0 { + data.Width = 512 + } + if data.Height <= 0 { + data.Height = 512 + } + if data.CfgScale <= 0 { + data.CfgScale = 7 + } + if data.Seed == 0 { + data.Seed = -1 + } + if data.Steps <= 0 { + data.Steps = 20 + } + if data.Sampler == "" { + data.Sampler = "Euler a" + } + idValue, _ := c.Get(types.LoginUserID) + userId := utils.IntValue(utils.InterfaceToString(idValue), 0) + taskId, err := h.snowflake.Next(true) + if err != nil { + resp.ERROR(c, "error with generate task id: "+err.Error()) + return + } + params := types.SdTaskParams{ + TaskId: taskId, + Prompt: data.Prompt, + NegPrompt: data.NegPrompt, + Steps: data.Steps, + Sampler: data.Sampler, + FaceFix: data.FaceFix, + CfgScale: data.CfgScale, + Seed: data.Seed, + Height: data.Height, + Width: data.Width, + HdFix: data.HdFix, + HdRedrawRate: data.HdRedrawRate, + HdScale: data.HdScale, + HdScaleAlg: data.HdScaleAlg, + HdSteps: data.HdSteps, + } + + job := model.SdJob{ + UserId: userId, + Type: types.TaskImage.String(), + TaskId: params.TaskId, + Params: utils.JsonEncode(params), + Prompt: data.Prompt, + Progress: 0, + Power: h.App.SysConfig.SdPower, + CreatedAt: time.Now(), + } + res := h.DB.Create(&job) + if res.Error != nil { + resp.ERROR(c, "error with save job: "+res.Error.Error()) + return + } + + h.pool.PushTask(types.SdTask{ + Id: int(job.Id), + SessionId: data.SessionId, + Type: types.TaskImage, + Params: params, + UserId: userId, + }) + + client := h.pool.Clients.Get(uint(job.UserId)) + if client != nil { + _ = client.Send([]byte("Task Updated")) + } + + // update user's power + tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power)) + // 记录算力变化日志 + if tx.Error == nil && tx.RowsAffected > 0 { + user, _ := h.GetLoginUser(c) + h.DB.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: job.Power, + Balance: user.Power - job.Power, + Mark: types.PowerSub, + Model: "stable-diffusion", + Remark: fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId), + CreatedAt: time.Now(), + }) + } + + resp.SUCCESS(c) +} + +// ImgWall 照片墙 +func (h *SdJobHandler) ImgWall(c *gin.Context) { + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + err, jobs := h.getData(true, 0, page, pageSize, true) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 SD 任务列表 +func (h *SdJobHandler) JobList(c *gin.Context) { + status := h.GetBool(c, "status") + userId := h.GetLoginUserId(c) + page := h.GetInt(c, "page", 0) + pageSize := h.GetInt(c, "page_size", 0) + publish := h.GetBool(c, "publish") + + err, jobs := h.getData(status, userId, page, pageSize, publish) + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + resp.SUCCESS(c, jobs) +} + +// JobList 获取 MJ 任务列表 +func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) { + + session := h.DB.Session(&gorm.Session{}) + if finish { + session = session.Where("progress = ?", 100).Order("id DESC") + } else { + session = session.Where("progress < ?", 100).Order("id ASC") + } + if userId > 0 { + session = session.Where("user_id = ?", userId) + } + if publish { + session = session.Where("publish", publish) + } + if page > 0 && pageSize > 0 { + offset := (page - 1) * pageSize + session = session.Offset(offset).Limit(pageSize) + } + + var items []model.SdJob + res := session.Find(&items) + if res.Error != nil { + return res.Error, nil + } + + var jobs = make([]vo.SdJob, 0) + for _, item := range items { + var job vo.SdJob + err := utils.CopyObject(item, &job) + if err != nil { + continue + } + + if item.Progress < 100 { + // 从 leveldb 中获取图片预览数据 + var imageData string + err = h.leveldb.Get(item.TaskId, &imageData) + if err == nil { + job.ImgURL = "data:image/png;base64," + imageData + } + } + jobs = append(jobs, job) + } + + return nil, jobs +} + +// Remove remove task image +func (h *SdJobHandler) Remove(c *gin.Context) { + var data struct { + Id uint `json:"id"` + UserId uint `json:"user_id"` + ImgURL string `json:"img_url"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + // remove job recode + res := h.DB.Delete(&model.SdJob{Id: data.Id}) + if res.Error != nil { + resp.ERROR(c, res.Error.Error()) + return + } + + // remove image + err := h.uploader.GetUploadHandler().Delete(data.ImgURL) + if err != nil { + logger.Error("remove image failed: ", err) + } + + client := h.pool.Clients.Get(data.UserId) + if client != nil { + _ = client.Send([]byte(sd.Finished)) + } + + resp.SUCCESS(c) +} + +// Publish 发布/取消发布图片到画廊显示 +func (h *SdJobHandler) Publish(c *gin.Context) { + var data struct { + Id uint `json:"id"` + Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享 + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败") + return + } + + resp.SUCCESS(c) +} diff --git a/handler/sms_handler.go b/handler/sms_handler.go new file mode 100644 index 0000000..7740a77 --- /dev/null +++ b/handler/sms_handler.go @@ -0,0 +1,93 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/core/types" + "geekai/service" + "geekai/service/sms" + "geekai/utils" + "geekai/utils/resp" + "strings" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" +) + +const CodeStorePrefix = "/verify/codes/" + +type SmsHandler struct { + BaseHandler + redis *redis.Client + sms *sms.ServiceManager + smtp *service.SmtpService + captcha *service.CaptchaService +} + +func NewSmsHandler( + app *core.AppServer, + client *redis.Client, + sms *sms.ServiceManager, + smtp *service.SmtpService, + captcha *service.CaptchaService) *SmsHandler { + return &SmsHandler{ + redis: client, + sms: sms, + captcha: captcha, + smtp: smtp, + BaseHandler: BaseHandler{App: app}} +} + +// SendCode 发送验证码 +func (h *SmsHandler) SendCode(c *gin.Context) { + var data struct { + Receiver string `json:"receiver"` // 接收者 + Key string `json:"key"` + Dots string `json:"dots"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + if !h.captcha.Check(data) { + resp.ERROR(c, "验证码错误,请先完人机验证") + return + } + + code := utils.RandomNumber(6) + var err error + if strings.Contains(data.Receiver, "@") { // email + if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") { + resp.ERROR(c, "系统已禁用邮箱注册!") + return + } + err = h.smtp.SendVerifyCode(data.Receiver, code) + } else { + if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") { + resp.ERROR(c, "系统已禁用手机号注册!") + return + } + err = h.sms.GetService().SendVerifyCode(data.Receiver, code) + + } + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + // 存储验证码,等待后面注册验证 + _, err = h.redis.Set(c, CodeStorePrefix+data.Receiver, code, 0).Result() + if err != nil { + resp.ERROR(c, "验证码保存失败") + return + } + + resp.SUCCESS(c) +} diff --git a/handler/test_handler.go b/handler/test_handler.go new file mode 100644 index 0000000..35aba79 --- /dev/null +++ b/handler/test_handler.go @@ -0,0 +1,17 @@ +package handler + +import ( + "geekai/service" + "geekai/service/payment" + "gorm.io/gorm" +) + +type TestHandler struct { + db *gorm.DB + snowflake *service.Snowflake + js *payment.PayJS +} + +func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler { + return &TestHandler{db: db, snowflake: snowflake, js: js} +} diff --git a/handler/upload_handler.go b/handler/upload_handler.go new file mode 100644 index 0000000..af37f97 --- /dev/null +++ b/handler/upload_handler.go @@ -0,0 +1,101 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core" + "geekai/service/oss" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "github.com/gin-gonic/gin" + "gorm.io/gorm" + "time" +) + +type UploadHandler struct { + BaseHandler + uploaderManager *oss.UploaderManager +} + +func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler { + return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager} +} + +func (h *UploadHandler) Upload(c *gin.Context) { + file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file") + if err != nil { + resp.ERROR(c, err.Error()) + return + } + + userId := h.GetLoginUserId(c) + res := h.DB.Create(&model.File{ + UserId: int(userId), + Name: file.Name, + ObjKey: file.ObjKey, + URL: file.URL, + Ext: file.Ext, + Size: file.Size, + CreatedAt: time.Time{}, + }) + if res.Error != nil || res.RowsAffected == 0 { + resp.ERROR(c, "error with update database: "+res.Error.Error()) + return + } + + resp.SUCCESS(c, file) +} + +func (h *UploadHandler) List(c *gin.Context) { + userId := h.GetLoginUserId(c) + var items []model.File + var files = make([]vo.File, 0) + h.DB.Where("user_id = ?", userId).Find(&items) + if len(items) > 0 { + for _, v := range items { + var file vo.File + err := utils.CopyObject(v, &file) + if err != nil { + logger.Error(err) + continue + } + file.CreatedAt = v.CreatedAt.Unix() + files = append(files, file) + } + } + + resp.SUCCESS(c, files) +} + +// Remove remove files +func (h *UploadHandler) Remove(c *gin.Context) { + userId := h.GetLoginUserId(c) + id := h.GetInt(c, "id", 0) + var file model.File + tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file) + if tx.Error != nil || file.Id == 0 { + resp.ERROR(c, "file not existed") + return + } + + // remove database + tx = h.DB.Model(&model.File{}).Delete("id = ?", id) + if tx.Error != nil || tx.RowsAffected == 0 { + resp.ERROR(c, "failed to update database") + return + } + // remove files + objectKey := file.ObjKey + if objectKey == "" { + objectKey = file.URL + } + _ = h.uploaderManager.GetUploadHandler().Delete(objectKey) + resp.SUCCESS(c) +} diff --git a/handler/user_handler.go b/handler/user_handler.go new file mode 100644 index 0000000..2718d15 --- /dev/null +++ b/handler/user_handler.go @@ -0,0 +1,440 @@ +package handler + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/service" + "geekai/store/model" + "geekai/store/vo" + "geekai/utils" + "geekai/utils/resp" + "strings" + "time" + + "github.com/go-redis/redis/v8" + "github.com/golang-jwt/jwt/v5" + + "github.com/gin-gonic/gin" + "github.com/lionsoul2014/ip2region/binding/golang/xdb" + "gorm.io/gorm" +) + +type UserHandler struct { + BaseHandler + searcher *xdb.Searcher + redis *redis.Client + licenseService *service.LicenseService +} + +func NewUserHandler( + app *core.AppServer, + db *gorm.DB, + searcher *xdb.Searcher, + client *redis.Client, + licenseService *service.LicenseService) *UserHandler { + return &UserHandler{ + BaseHandler: BaseHandler{DB: db, App: app}, + searcher: searcher, + redis: client, + licenseService: licenseService, + } +} + +// Register user register +func (h *UserHandler) Register(c *gin.Context) { + // parameters process + var data struct { + RegWay string `json:"reg_way"` + Username string `json:"username"` + Password string `json:"password"` + Code string `json:"code"` + InviteCode string `json:"invite_code"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + data.Password = strings.TrimSpace(data.Password) + if len(data.Password) < 8 { + resp.ERROR(c, "密码长度不能少于8个字符") + return + } + + // 检测最大注册人数 + var totalUser int64 + h.DB.Model(&model.User{}).Count(&totalUser) + if h.licenseService.GetLicense().Configs.UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().Configs.UserNum { + resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") + return + } + + // 检查验证码 + var key string + if data.RegWay == "email" || data.RegWay == "mobile" { + key = CodeStorePrefix + data.Username + code, err := h.redis.Get(c, key).Result() + if err != nil || code != data.Code { + resp.ERROR(c, "验证码错误") + return + } + } + + // 验证邀请码 + inviteCode := model.InviteCode{} + if data.InviteCode != "" { + res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode) + if res.Error != nil { + resp.ERROR(c, "无效的邀请码") + return + } + } + + // check if the username is exists + var item model.User + res := h.DB.Where("username = ?", data.Username).First(&item) + if item.Id > 0 { + resp.ERROR(c, "该用户名已经被注册") + return + } + + salt := utils.RandString(8) + user := model.User{ + Username: data.Username, + Password: utils.GenPassword(data.Password, salt), + Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)), + Avatar: "/images/avatar/user.png", + Salt: salt, + Status: true, + ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色 + ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型 + Power: h.App.SysConfig.InitPower, + } + + res = h.DB.Create(&user) + if res.Error != nil { + resp.ERROR(c, "保存数据失败") + logger.Error(res.Error) + return + } + + // 记录邀请关系 + if data.InviteCode != "" { + // 增加邀请数量 + h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1)) + if h.App.SysConfig.InvitePower > 0 { + h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower)) + // 记录邀请算力充值日志 + var inviter model.User + h.DB.Where("id", inviteCode.UserId).First(&inviter) + h.DB.Create(&model.PowerLog{ + UserId: inviter.Id, + Username: inviter.Username, + Type: types.PowerInvite, + Amount: h.App.SysConfig.InvitePower, + Balance: inviter.Power, + Mark: types.PowerAdd, + Model: "", + Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username), + CreatedAt: time.Now(), + }) + } + + // 添加邀请记录 + h.DB.Create(&model.InviteLog{ + InviterId: inviteCode.UserId, + UserId: user.Id, + Username: user.Username, + InviteCode: inviteCode.Code, + Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower), + }) + } + + _ = h.redis.Del(c, key) // 注册成功,删除短信验证码 + + // 自动登录创建 token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": user.Id, + "expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(), + }) + tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey)) + if err != nil { + resp.ERROR(c, "Failed to generate token, "+err.Error()) + return + } + // 保存到 redis + key = fmt.Sprintf("users/%d", user.Id) + if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil { + resp.ERROR(c, "error with save token: "+err.Error()) + return + } + resp.SUCCESS(c, tokenString) +} + +// Login 用户登录 +func (h *UserHandler) Login(c *gin.Context) { + var data struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + var user model.User + res := h.DB.Where("username = ?", data.Username).First(&user) + if res.Error != nil { + resp.ERROR(c, "用户名不存在") + return + } + + password := utils.GenPassword(data.Password, user.Salt) + if password != user.Password { + resp.ERROR(c, "用户名或密码错误") + return + } + + if user.Status == false { + resp.ERROR(c, "该用户已被禁止登录,请联系管理员") + return + } + + // 更新最后登录时间和IP + user.LastLoginIp = c.ClientIP() + user.LastLoginAt = time.Now().Unix() + h.DB.Model(&user).Updates(user) + + h.DB.Create(&model.UserLoginLog{ + UserId: user.Id, + Username: user.Username, + LoginIp: c.ClientIP(), + LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()), + }) + + // 创建 token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "user_id": user.Id, + "expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(), + }) + tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey)) + if err != nil { + resp.ERROR(c, "Failed to generate token, "+err.Error()) + return + } + // 保存到 redis + key := fmt.Sprintf("users/%d", user.Id) + if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil { + resp.ERROR(c, "error with save token: "+err.Error()) + return + } + resp.SUCCESS(c, tokenString) +} + +// Logout 注 销 +func (h *UserHandler) Logout(c *gin.Context) { + key := h.GetUserKey(c) + if _, err := h.redis.Del(c, key).Result(); err != nil { + logger.Error("error with delete session: ", err) + } + resp.SUCCESS(c) +} + +// Session 获取/验证会话 +func (h *UserHandler) Session(c *gin.Context) { + user, err := h.GetLoginUser(c) + if err == nil { + var userVo vo.User + err := utils.CopyObject(user, &userVo) + if err != nil { + resp.ERROR(c) + } + userVo.Id = user.Id + resp.SUCCESS(c, userVo) + } else { + resp.NotAuth(c) + } + +} + +type userProfile struct { + Id uint `json:"id"` + Nickname string `json:"nickname"` + Username string `json:"username"` + Avatar string `json:"avatar"` + Power int `json:"power"` + ExpiredTime int64 `json:"expired_time"` + Vip bool `json:"vip"` +} + +func (h *UserHandler) Profile(c *gin.Context) { + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + h.DB.First(&user, user.Id) + var profile userProfile + err = utils.CopyObject(user, &profile) + if err != nil { + logger.Error("对象拷贝失败:", err.Error()) + resp.ERROR(c, "获取用户信息失败") + return + } + + profile.Id = user.Id + resp.SUCCESS(c, profile) +} + +func (h *UserHandler) ProfileUpdate(c *gin.Context) { + var data userProfile + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + h.DB.First(&user, user.Id) + user.Avatar = data.Avatar + user.Nickname = data.Nickname + res := h.DB.Updates(&user) + if res.Error != nil { + resp.ERROR(c, "更新用户信息失败") + return + } + + resp.SUCCESS(c) +} + +// UpdatePass 更新密码 +func (h *UserHandler) UpdatePass(c *gin.Context) { + var data struct { + OldPass string `json:"old_pass"` + Password string `json:"password"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + if len(data.Password) < 8 { + resp.ERROR(c, "密码长度不能少于8个字符") + return + } + + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + password := utils.GenPassword(data.OldPass, user.Salt) + logger.Debugf(user.Salt, ",", user.Password, ",", password, ",", data.OldPass) + if password != user.Password { + resp.ERROR(c, "原密码错误") + return + } + + newPass := utils.GenPassword(data.Password, user.Salt) + res := h.DB.Model(&user).UpdateColumn("password", newPass) + if res.Error != nil { + logger.Error("error with update database:", res.Error) + resp.ERROR(c, "更新数据库失败") + return + } + + resp.SUCCESS(c) +} + +// ResetPass 重置密码 +func (h *UserHandler) ResetPass(c *gin.Context) { + var data struct { + Username string `json:"username"` + Code string `json:"code"` // 验证码 + Password string `json:"password"` // 新密码 + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + var user model.User + res := h.DB.Where("username", data.Username).First(&user) + if res.Error != nil { + resp.ERROR(c, "用户不存在!") + return + } + + // 检查验证码 + key := CodeStorePrefix + data.Username + code, err := h.redis.Get(c, key).Result() + if err != nil || code != data.Code { + resp.ERROR(c, "短信验证码错误") + return + } + + password := utils.GenPassword(data.Password, user.Salt) + user.Password = password + res = h.DB.Updates(&user) + if res.Error != nil { + resp.ERROR(c) + } else { + h.redis.Del(c, key) + resp.SUCCESS(c) + } +} + +// BindUsername 重置账号 +func (h *UserHandler) BindUsername(c *gin.Context) { + var data struct { + Username string `json:"username"` + Code string `json:"code"` + } + if err := c.ShouldBindJSON(&data); err != nil { + resp.ERROR(c, types.InvalidArgs) + return + } + + // 检查验证码 + key := CodeStorePrefix + data.Username + code, err := h.redis.Get(c, key).Result() + if err != nil || code != data.Code { + resp.ERROR(c, "验证码错误") + return + } + + // 检查手机号是否被其他账号绑定 + var item model.User + res := h.DB.Where("username = ?", data.Username).First(&item) + if res.Error == nil { + resp.ERROR(c, "该账号已经被其他账号绑定") + return + } + + user, err := h.GetLoginUser(c) + if err != nil { + resp.NotAuth(c) + return + } + + res = h.DB.Model(&user).UpdateColumn("username", data.Username) + if res.Error != nil { + logger.Error(res.Error) + resp.ERROR(c, "更新数据库失败") + return + } + + _ = h.redis.Del(c, key) // 删除短信验证码 + resp.SUCCESS(c) +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..cd50840 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,81 @@ +package logger + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "gopkg.in/natefinch/lumberjack.v2" + "os" + "strings" +) + +var logger *zap.Logger +var sugarLogger *zap.SugaredLogger + +func GetLogger() *zap.SugaredLogger { + if sugarLogger != nil { + return sugarLogger + } + + logLevel := zap.NewAtomicLevelAt(getLogLevel(os.Getenv("LOG_LEVEL"))) + encoder := getEncoder() + writerSyncer := getLogWriter() + fileCore := zapcore.NewCore(encoder, writerSyncer, logLevel) + consoleOutput := zapcore.Lock(os.Stdout) + consoleCore := zapcore.NewCore( + encoder, + consoleOutput, + logLevel, + ) + core := zapcore.NewTee(fileCore, consoleCore) + logger = zap.New(core, zap.AddCaller()) + sugarLogger = logger.Sugar() + return sugarLogger +} + +// core 三个参数之 编码 +func getEncoder() zapcore.Encoder { + encoderConfig := zapcore.EncoderConfig{ + TimeKey: "time", + LevelKey: "level", + NameKey: "logger", + CallerKey: "caller", + MessageKey: "msg", + StacktraceKey: "stacktrace", + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeDuration: zapcore.SecondsDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + EncodeLevel: zapcore.CapitalLevelEncoder, + } + return zapcore.NewConsoleEncoder(encoderConfig) +} + +func getLogWriter() zapcore.WriteSyncer { + lumberJackLogger := &lumberjack.Logger{ + Filename: "logs/app.log", + MaxSize: 10, + MaxBackups: 5, + MaxAge: 30, + Compress: false, + } + return zapcore.AddSync(lumberJackLogger) +} + +func getLogLevel(level string) zapcore.Level { + switch strings.ToUpper(level) { + case "DEBUG": + return zapcore.DebugLevel + case "WARN": + return zapcore.WarnLevel + case "ERROR": + return zapcore.ErrorLevel + default: + return zapcore.InfoLevel + } +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..8b585ee --- /dev/null +++ b/main.go @@ -0,0 +1,521 @@ +package main + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "context" + "embed" + "geekai/core" + "geekai/core/types" + "geekai/handler" + "geekai/handler/admin" + "geekai/handler/chatimpl" + logger2 "geekai/logger" + "geekai/service" + "geekai/service/dalle" + "geekai/service/mj" + "geekai/service/oss" + "geekai/service/payment" + "geekai/service/sd" + "geekai/service/sms" + "geekai/service/wx" + "geekai/store" + "io" + "log" + "os" + "os/signal" + "strconv" + "syscall" + "time" + + "github.com/go-redis/redis/v8" + + "github.com/lionsoul2014/ip2region/binding/golang/xdb" + "go.uber.org/fx" + "gorm.io/gorm" +) + +var logger = logger2.GetLogger() + +//go:embed res +var xdbFS embed.FS + +// AppLifecycle 应用程序生命周期 +type AppLifecycle struct { +} + +// OnStart 应用程序启动时执行 +func (l *AppLifecycle) OnStart(context.Context) error { + logger.Info("AppLifecycle OnStart") + return nil +} + +// OnStop 应用程序停止时执行 +func (l *AppLifecycle) OnStop(context.Context) error { + logger.Info("AppLifecycle OnStop") + return nil +} + +func NewAppLifeCycle() *AppLifecycle { + return &AppLifecycle{} +} + +func main() { + configFile := os.Getenv("CONFIG_FILE") + if configFile == "" { + configFile = "config.toml" + } + debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG")) + logger.Info("Loading config file: ", configFile) + if !debug { + defer func() { + if err := recover(); err != nil { + logger.Error("Panic Error:", err) + } + }() + } + + app := fx.New( + // 初始化配置应用配置 + fx.Provide(func() *types.AppConfig { + config, err := core.LoadConfig(configFile) + if err != nil { + log.Fatal(err) + } + config.Path = configFile + if debug { + _ = core.SaveConfig(config) + } + return config + }), + // 创建应用服务 + fx.Provide(core.NewServer), + // 初始化 + fx.Invoke(func(s *core.AppServer, client *redis.Client) { + s.Init(debug, client) + }), + + // 初始化数据库 + fx.Provide(store.NewGormConfig), + fx.Provide(store.NewMysql), + fx.Provide(store.NewRedisClient), + fx.Provide(store.NewLevelDB), + + fx.Provide(func() embed.FS { + return xdbFS + }), + + // 创建 Ip2Region 查询对象 + fx.Provide(func() (*xdb.Searcher, error) { + file, err := xdbFS.Open("res/ip2region.xdb") + if err != nil { + return nil, err + } + cBuff, err := io.ReadAll(file) + if err != nil { + return nil, err + } + + return xdb.NewWithBuffer(cBuff) + }), + + // 创建控制器 + fx.Provide(handler.NewChatRoleHandler), + fx.Provide(handler.NewUserHandler), + fx.Provide(chatimpl.NewChatHandler), + fx.Provide(handler.NewUploadHandler), + fx.Provide(handler.NewSmsHandler), + fx.Provide(handler.NewRewardHandler), + fx.Provide(handler.NewCaptchaHandler), + fx.Provide(handler.NewMidJourneyHandler), + fx.Provide(handler.NewChatModelHandler), + fx.Provide(handler.NewSdJobHandler), + fx.Provide(handler.NewPaymentHandler), + fx.Provide(handler.NewOrderHandler), + fx.Provide(handler.NewProductHandler), + fx.Provide(handler.NewConfigHandler), + fx.Provide(handler.NewPowerLogHandler), + + fx.Provide(admin.NewConfigHandler), + fx.Provide(admin.NewAdminHandler), + fx.Provide(admin.NewApiKeyHandler), + fx.Provide(admin.NewUserHandler), + fx.Provide(admin.NewChatRoleHandler), + fx.Provide(admin.NewRewardHandler), + fx.Provide(admin.NewDashboardHandler), + fx.Provide(admin.NewChatModelHandler), + fx.Provide(admin.NewProductHandler), + fx.Provide(admin.NewOrderHandler), + fx.Provide(admin.NewChatHandler), + fx.Provide(admin.NewPowerLogHandler), + + // 创建服务 + fx.Provide(sms.NewSendServiceManager), + fx.Provide(func(config *types.AppConfig) *service.CaptchaService { + return service.NewCaptchaService(config.ApiConfig) + }), + fx.Provide(oss.NewUploaderManager), + fx.Provide(mj.NewService), + fx.Provide(dalle.NewService), + fx.Invoke(func(service *dalle.Service) { + service.Run() + service.CheckTaskNotify() + service.DownloadImages() + service.CheckTaskStatus() + }), + + // 邮件服务 + fx.Provide(service.NewSmtpService), + // License 服务 + fx.Provide(service.NewLicenseService), + fx.Invoke(func(licenseService *service.LicenseService) { + licenseService.SyncLicense() + }), + + // 微信机器人服务 + fx.Provide(wx.NewWeChatBot), + fx.Invoke(func(config *types.AppConfig, bot *wx.Bot) { + if config.WeChatBot { + err := bot.Run() + if err != nil { + logger.Error("微信登录失败:", err) + } + } + }), + + // MidJourney service pool + fx.Provide(mj.NewServicePool), + fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) { + pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs) + if pool.HasAvailableService() { + pool.DownloadImages() + pool.CheckTaskNotify() + pool.SyncTaskProgress() + } + }), + + // Stable Diffusion 机器人 + fx.Provide(sd.NewServicePool), + fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) { + pool.InitServices(config.SdConfigs) + if pool.HasAvailableService() { + pool.CheckTaskNotify() + pool.CheckTaskStatus() + } + }), + + fx.Provide(payment.NewAlipayService), + fx.Provide(payment.NewHuPiPay), + fx.Provide(payment.NewPayJS), + fx.Provide(service.NewSnowflake), + fx.Provide(service.NewXXLJobExecutor), + fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) { + if config.XXLConfig.Enabled { + go func() { + log.Fatal(exec.Run()) + }() + } + }), + + // 注册路由 + fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) { + group := s.Engine.Group("/api/role/") + group.GET("list", h.List) + group.POST("update", h.UpdateRole) + }), + fx.Invoke(func(s *core.AppServer, h *handler.UserHandler) { + group := s.Engine.Group("/api/user/") + group.POST("register", h.Register) + group.POST("login", h.Login) + group.GET("logout", h.Logout) + group.GET("session", h.Session) + group.GET("profile", h.Profile) + group.POST("profile/update", h.ProfileUpdate) + group.POST("password", h.UpdatePass) + group.POST("bind/username", h.BindUsername) + group.POST("resetPass", h.ResetPass) + }), + fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) { + group := s.Engine.Group("/api/chat/") + group.Any("new", h.ChatHandle) + group.GET("list", h.List) + group.GET("detail", h.Detail) + group.POST("update", h.Update) + group.GET("remove", h.Remove) + group.GET("history", h.History) + group.GET("clear", h.Clear) + group.POST("tokens", h.Tokens) + group.GET("stop", h.StopGenerate) + }), + fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) { + s.Engine.POST("/api/upload", h.Upload) + s.Engine.GET("/api/upload/list", h.List) + s.Engine.GET("/api/upload/remove", h.Remove) + }), + fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) { + group := s.Engine.Group("/api/sms/") + group.POST("code", h.SendCode) + }), + fx.Invoke(func(s *core.AppServer, h *handler.CaptchaHandler) { + group := s.Engine.Group("/api/captcha/") + group.GET("get", h.Get) + group.POST("check", h.Check) + group.GET("slide/get", h.SlideGet) + group.POST("slide/check", h.SlideCheck) + }), + fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) { + group := s.Engine.Group("/api/reward/") + group.POST("verify", h.Verify) + }), + fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { + group := s.Engine.Group("/api/mj/") + group.Any("client", h.Client) + group.POST("image", h.Image) + group.POST("upscale", h.Upscale) + group.POST("variation", h.Variation) + group.GET("jobs", h.JobList) + group.GET("imgWall", h.ImgWall) + group.POST("remove", h.Remove) + group.POST("publish", h.Publish) + }), + fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { + group := s.Engine.Group("/api/sd") + group.Any("client", h.Client) + group.POST("image", h.Image) + group.GET("jobs", h.JobList) + group.GET("imgWall", h.ImgWall) + group.POST("remove", h.Remove) + group.POST("publish", h.Publish) + }), + fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) { + group := s.Engine.Group("/api/config/") + group.GET("get", h.Get) + group.GET("license", h.License) + }), + + // 管理后台控制器 + fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) { + group := s.Engine.Group("/api/admin/") + group.POST("config/update", h.Update) + group.GET("config/get", h.Get) + group.POST("active", h.Active) + group.GET("config/get/license", h.GetLicense) + group.GET("config/get/app", h.GetAppConfig) + group.POST("config/update/draw", h.SaveDrawingConfig) + }), + fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) { + group := s.Engine.Group("/api/admin/") + group.POST("login", h.Login) + group.GET("logout", h.Logout) + group.GET("session", h.Session) + group.GET("list", h.List) + group.POST("save", h.Save) + group.POST("enable", h.Enable) + group.GET("remove", h.Remove) + group.POST("resetPass", h.ResetPass) + }), + fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) { + group := s.Engine.Group("/api/admin/apikey/") + group.POST("save", h.Save) + group.GET("list", h.List) + group.POST("set", h.Set) + group.GET("remove", h.Remove) + }), + fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) { + group := s.Engine.Group("/api/admin/user/") + group.GET("list", h.List) + group.POST("save", h.Save) + group.GET("remove", h.Remove) + group.GET("loginLog", h.LoginLog) + group.POST("resetPass", h.ResetPass) + }), + fx.Invoke(func(s *core.AppServer, h *admin.ChatRoleHandler) { + group := s.Engine.Group("/api/admin/role/") + group.GET("list", h.List) + group.POST("save", h.Save) + group.POST("sort", h.Sort) + group.POST("set", h.Set) + group.GET("remove", h.Remove) + }), + fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) { + group := s.Engine.Group("/api/admin/reward/") + group.GET("list", h.List) + group.POST("remove", h.Remove) + }), + fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) { + group := s.Engine.Group("/api/admin/dashboard/") + group.GET("stats", h.Stats) + }), + fx.Invoke(func(s *core.AppServer, h *handler.ChatModelHandler) { + group := s.Engine.Group("/api/model/") + group.GET("list", h.List) + }), + fx.Invoke(func(s *core.AppServer, h *admin.ChatModelHandler) { + group := s.Engine.Group("/api/admin/model/") + group.POST("save", h.Save) + group.GET("list", h.List) + group.POST("set", h.Set) + group.POST("sort", h.Sort) + group.GET("remove", h.Remove) + }), + fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) { + group := s.Engine.Group("/api/payment/") + group.GET("doPay", h.DoPay) + group.GET("payWays", h.GetPayWays) + group.POST("query", h.OrderQuery) + group.POST("qrcode", h.PayQrcode) + group.POST("mobile", h.Mobile) + group.POST("alipay/notify", h.AlipayNotify) + group.POST("hupipay/notify", h.HuPiPayNotify) + group.POST("payjs/notify", h.PayJsNotify) + }), + fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) { + group := s.Engine.Group("/api/admin/product/") + group.POST("save", h.Save) + group.GET("list", h.List) + group.POST("enable", h.Enable) + group.POST("sort", h.Sort) + group.GET("remove", h.Remove) + }), + fx.Invoke(func(s *core.AppServer, h *admin.OrderHandler) { + group := s.Engine.Group("/api/admin/order/") + group.POST("list", h.List) + group.GET("remove", h.Remove) + }), + fx.Invoke(func(s *core.AppServer, h *handler.OrderHandler) { + group := s.Engine.Group("/api/order/") + group.POST("list", h.List) + }), + fx.Invoke(func(s *core.AppServer, h *handler.ProductHandler) { + group := s.Engine.Group("/api/product/") + group.GET("list", h.List) + }), + + fx.Provide(handler.NewInviteHandler), + fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) { + group := s.Engine.Group("/api/invite/") + group.GET("code", h.Code) + group.POST("list", h.List) + group.GET("hits", h.Hits) + }), + + fx.Provide(admin.NewFunctionHandler), + fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) { + group := s.Engine.Group("/api/admin/function/") + group.POST("save", h.Save) + group.POST("set", h.Set) + group.GET("list", h.List) + group.GET("remove", h.Remove) + group.GET("token", h.GenToken) + }), + + // 验证码 + fx.Provide(admin.NewCaptchaHandler), + fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) { + group := s.Engine.Group("/api/admin/login/") + group.GET("captcha", h.GetCaptcha) + }), + + fx.Provide(admin.NewUploadHandler), + fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) { + s.Engine.POST("/api/admin/upload", h.Upload) + }), + + fx.Provide(handler.NewFunctionHandler), + fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) { + group := s.Engine.Group("/api/function/") + group.POST("weibo", h.WeiBo) + group.POST("zaobao", h.ZaoBao) + group.POST("dalle3", h.Dall3) + }), + fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) { + group := s.Engine.Group("/api/admin/chat/") + group.POST("list", h.List) + group.POST("message", h.Messages) + group.GET("history", h.History) + group.GET("remove", h.RemoveChat) + group.GET("message/remove", h.RemoveMessage) + }), + fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) { + group := s.Engine.Group("/api/powerLog/") + group.POST("list", h.List) + }), + fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) { + group := s.Engine.Group("/api/admin/powerLog/") + group.POST("list", h.List) + }), + fx.Provide(admin.NewMenuHandler), + fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) { + group := s.Engine.Group("/api/admin/menu/") + group.POST("save", h.Save) + group.GET("list", h.List) + group.POST("enable", h.Enable) + group.POST("sort", h.Sort) + group.GET("remove", h.Remove) + }), + fx.Provide(handler.NewMenuHandler), + fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) { + group := s.Engine.Group("/api/menu/") + group.GET("list", h.List) + }), + fx.Provide(handler.NewMarkMapHandler), + fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) { + group := s.Engine.Group("/api/markMap/") + group.Any("client", h.Client) + }), + fx.Provide(handler.NewDallJobHandler), + fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) { + group := s.Engine.Group("/api/dall") + group.Any("client", h.Client) + group.POST("image", h.Image) + group.GET("jobs", h.JobList) + group.GET("imgWall", h.ImgWall) + group.POST("remove", h.Remove) + group.POST("publish", h.Publish) + }), + fx.Invoke(func(s *core.AppServer, db *gorm.DB) { + go func() { + err := s.Run(db) + if err != nil { + log.Fatal(err) + } + }() + }), + fx.Provide(NewAppLifeCycle), + // 注册生命周期回调函数 + fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) { + lifecycle.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return lc.OnStart(ctx) + }, + OnStop: func(ctx context.Context) error { + return lc.OnStop(ctx) + }, + }) + }), + ) + // 启动应用程序 + go func() { + if err := app.Start(context.Background()); err != nil { + log.Fatal(err) + } + }() + + // 监听退出信号 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + // 关闭应用程序 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := app.Stop(ctx); err != nil { + log.Fatal(err) + } + +} diff --git a/res/certs/alipay/alipayPublicCert.crt b/res/certs/alipay/alipayPublicCert.crt new file mode 100644 index 0000000..82013d6 --- /dev/null +++ b/res/certs/alipay/alipayPublicCert.crt @@ -0,0 +1,38 @@ +-----BEGIN CERTIFICATE----- +MIIDszCCApugAwIBAgIQICMRB0rBU2/rZJbfJGMYIzANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE +BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv +biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1 +dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDYzNTQxWhcNMjQxMTA2MDYzNTQxWjCB +hDELMAkGA1UEBhMCQ04xHzAdBgNVBAoMFm1ib25meTkwMTVAc2FuZGJveC5jb20xDzANBgNVBAsM +BkFsaXBheTFDMEEGA1UEAww65pSv5LuY5a6dKOS4reWbvSnnvZHnu5zmioDmnK/mnInpmZDlhazl +j7gtMjA4ODcyMTAyMDc1MDU4MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKsoKcw5 +sxaiyV7mpWzDtnQ1K518eQLP0+dJlZAf06aBep/Aj9DIqrba/k7DHt8dKQvILMLAMpN1+2IRxbaO +yxMa/laj3lZ1eHrB6F077O3D62oHcE3noZtXL0N1zZAxpmkNmYIHeLZS2oLMS4ANu47O/wpDC7BV +HjdpZugtdPJ4mxdCpM9GDdLs7W4s5QI4PUPK4skFNMFoKI+0cYP/9ju87UP//IHC/K510GWNl+Gn +Cvgag3AmiIB0utJNsGhxm6zT1T9tUWjW9iz/BxBKiPatsCX9VpPQzGnW7ZonRQtiZSokIlP2IPvl +H5DcwpWUz3/LUY0SmKxnKOEYeOOqCW8CAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3 +DQEBCwUAA4IBAQAtgxF2EzjOndEFxBUD9tFwcSt6XKGggOp52oft1pvynPg4ALTLafOtfEPDrFBH +PwpYrSu9s9C8NJtaA2HrlCfBjIuwEFTXiN+HPvS0SwSPKt9AXEiTcOF8vDcGamEen8QI4fo5Jia7 +2VRKkerkww5/+FzSaVO7ZUKuL80M1QJStmAZc8kPPwdYOTTW2bGf8BcmSDL6SPElBkt7tCCRd4sn ++jq4cZ0yb2i77rBZCwHcTvfTqIBblPwLv4uGvg3+83BxIB5w6Kqp06bKEAPmobFY5IVHa+ON0/qi +BXxXr+WQ3piKRVQEN64+PTAjSc67Ix1umvpLl3Ko6Ry7NJmpDcUn +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDszCCApugAwIBAgIQIBkIGbgVxq210KxLJ+YA/TANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UE +BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxJTAjBgNVBAsMHENlcnRpZmljYXRpb24gQXV0 +aG9yaXR5IHRlc3QxNjA0BgNVBAMMLUFudCBGaW5hbmNpYWwgQ2VydGlmaWNhdGlvbiBBdXRob3Jp +dHkgUjEgdGVzdDAeFw0xOTA4MTkxMTE2MDBaFw0yNDA4MDExMTE2MDBaMIGRMQswCQYDVQQGEwJD +TjEbMBkGA1UECgwSQW50IEZpbmFuY2lhbCB0ZXN0MSUwIwYDVQQLDBxDZXJ0aWZpY2F0aW9uIEF1 +dGhvcml0eSB0ZXN0MT4wPAYDVQQDDDVBbnQgRmluYW5jaWFsIENlcnRpZmljYXRpb24gQXV0aG9y +aXR5IENsYXNzIDIgUjEgdGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMh4FKYO +ZyRQHD6eFbPKZeSAnrfjfU7xmS9Yoozuu+iuqZlb6Z0SPLUqqTZAFZejOcmr07ln/pwZxluqplxC +5+B48End4nclDMlT5HPrDr3W0frs6Xsa2ZNcyil/iKNB5MbGll8LRAxntsKvZZj6vUTMb705gYgm +VUMILwi/ZxKTQqBtkT/kQQ5y6nOZsj7XI5rYdz6qqOROrpvS/d7iypdHOMIM9Iz9DlL1mrCykbBi +t25y+gTeXmuisHUwqaRpwtCGK4BayCqxRGbNipe6W73EK9lBrrzNtTr9NaysesT/v+l25JHCL9tG +wpNr1oWFzk4IHVOg0ORiQ6SUgxZUTYcCAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3 +DQEBCwUAA4IBAQBWThEoIaQoBX2YeRY/I8gu6TYnFXtyuCljANnXnM38ft+ikhE5mMNgKmJYLHvT +yWWWgwHoSAWEuml7EGbE/2AK2h3k0MdfiWLzdmpPCRG/RJHk6UB1pMHPilI+c0MVu16OPpKbg5Vf +LTv7dsAB40AzKsvyYw88/Ezi1osTXo6QQwda7uefvudirtb8FcQM9R66cJxl3kt1FXbpYwheIm/p +j1mq64swCoIYu4NrsUYtn6CV542DTQMI5QdXkn+PzUUly8F6kDp+KpMNd0avfWNL5+O++z+F5Szy +1CPta1D7EQ/eYmMP+mOQ35oifWIoFCpN6qQVBS/Hob1J/UUyg7BW +-----END CERTIFICATE----- diff --git a/res/certs/alipay/alipayRootCert.crt b/res/certs/alipay/alipayRootCert.crt new file mode 100644 index 0000000..76417c5 --- /dev/null +++ b/res/certs/alipay/alipayRootCert.crt @@ -0,0 +1,88 @@ +-----BEGIN CERTIFICATE----- +MIIBszCCAVegAwIBAgIIaeL+wBcKxnswDAYIKoEcz1UBg3UFADAuMQswCQYDVQQG +EwJDTjEOMAwGA1UECgwFTlJDQUMxDzANBgNVBAMMBlJPT1RDQTAeFw0xMjA3MTQw +MzExNTlaFw00MjA3MDcwMzExNTlaMC4xCzAJBgNVBAYTAkNOMQ4wDAYDVQQKDAVO +UkNBQzEPMA0GA1UEAwwGUk9PVENBMFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAE +MPCca6pmgcchsTf2UnBeL9rtp4nw+itk1Kzrmbnqo05lUwkwlWK+4OIrtFdAqnRT +V7Q9v1htkv42TsIutzd126NdMFswHwYDVR0jBBgwFoAUTDKxl9kzG8SmBcHG5Yti +W/CXdlgwDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAQYwHQYDVR0OBBYEFEwysZfZ +MxvEpgXBxuWLYlvwl3ZYMAwGCCqBHM9VAYN1BQADSAAwRQIgG1bSLeOXp3oB8H7b +53W+CKOPl2PknmWEq/lMhtn25HkCIQDaHDgWxWFtnCrBjH16/W3Ezn7/U/Vjo5xI +pDoiVhsLwg== +-----END CERTIFICATE----- + +-----BEGIN CERTIFICATE----- +MIIF0zCCA7ugAwIBAgIIH8+hjWpIDREwDQYJKoZIhvcNAQELBQAwejELMAkGA1UE +BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNVBAsMF0NlcnRpZmlj +YXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5jaWFsIENlcnRpZmlj +YXRpb24gQXV0aG9yaXR5IFIxMB4XDTE4MDMyMTEzNDg0MFoXDTM4MDIyODEzNDg0 +MFowejELMAkGA1UEBhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNV +BAsMF0NlcnRpZmljYXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5j +aWFsIENlcnRpZmljYXRpb24gQXV0aG9yaXR5IFIxMIICIjANBgkqhkiG9w0BAQEF +AAOCAg8AMIICCgKCAgEAtytTRcBNuur5h8xuxnlKJetT65cHGemGi8oD+beHFPTk +rUTlFt9Xn7fAVGo6QSsPb9uGLpUFGEdGmbsQ2q9cV4P89qkH04VzIPwT7AywJdt2 +xAvMs+MgHFJzOYfL1QkdOOVO7NwKxH8IvlQgFabWomWk2Ei9WfUyxFjVO1LVh0Bp +dRBeWLMkdudx0tl3+21t1apnReFNQ5nfX29xeSxIhesaMHDZFViO/DXDNW2BcTs6 +vSWKyJ4YIIIzStumD8K1xMsoaZBMDxg4itjWFaKRgNuPiIn4kjDY3kC66Sl/6yTl +YUz8AybbEsICZzssdZh7jcNb1VRfk79lgAprm/Ktl+mgrU1gaMGP1OE25JCbqli1 +Pbw/BpPynyP9+XulE+2mxFwTYhKAwpDIDKuYsFUXuo8t261pCovI1CXFzAQM2w7H +DtA2nOXSW6q0jGDJ5+WauH+K8ZSvA6x4sFo4u0KNCx0ROTBpLif6GTngqo3sj+98 +SZiMNLFMQoQkjkdN5Q5g9N6CFZPVZ6QpO0JcIc7S1le/g9z5iBKnifrKxy0TQjtG +PsDwc8ubPnRm/F82RReCoyNyx63indpgFfhN7+KxUIQ9cOwwTvemmor0A+ZQamRe +9LMuiEfEaWUDK+6O0Gl8lO571uI5onYdN1VIgOmwFbe+D8TcuzVjIZ/zvHrAGUcC +AwEAAaNdMFswCwYDVR0PBAQDAgEGMAwGA1UdEwQFMAMBAf8wHQYDVR0OBBYEFF90 +tATATwda6uWx2yKjh0GynOEBMB8GA1UdIwQYMBaAFF90tATATwda6uWx2yKjh0Gy +nOEBMA0GCSqGSIb3DQEBCwUAA4ICAQCVYaOtqOLIpsrEikE5lb+UARNSFJg6tpkf +tJ2U8QF/DejemEHx5IClQu6ajxjtu0Aie4/3UnIXop8nH/Q57l+Wyt9T7N2WPiNq +JSlYKYbJpPF8LXbuKYG3BTFTdOVFIeRe2NUyYh/xs6bXGr4WKTXb3qBmzR02FSy3 +IODQw5Q6zpXj8prYqFHYsOvGCEc1CwJaSaYwRhTkFedJUxiyhyB5GQwoFfExCVHW +05ZFCAVYFldCJvUzfzrWubN6wX0DD2dwultgmldOn/W/n8at52mpPNvIdbZb2F41 +T0YZeoWnCJrYXjq/32oc1cmifIHqySnyMnavi75DxPCdZsCOpSAT4j4lAQRGsfgI +kkLPGQieMfNNkMCKh7qjwdXAVtdqhf0RVtFILH3OyEodlk1HYXqX5iE5wlaKzDop +PKwf2Q3BErq1xChYGGVS+dEvyXc/2nIBlt7uLWKp4XFjqekKbaGaLJdjYP5b2s7N +1dM0MXQ/f8XoXKBkJNzEiM3hfsU6DOREgMc1DIsFKxfuMwX3EkVQM1If8ghb6x5Y +jXayv+NLbidOSzk4vl5QwngO/JYFMkoc6i9LNwEaEtR9PhnrdubxmrtM+RjfBm02 +77q3dSWFESFQ4QxYWew4pHE0DpWbWy/iMIKQ6UZ5RLvB8GEcgt8ON7BBJeMc+Dyi +kT9qhqn+lw== +-----END CERTIFICATE----- + +-----BEGIN CERTIFICATE----- +MIICiDCCAgygAwIBAgIIQX76UsB/30owDAYIKoZIzj0EAwMFADB6MQswCQYDVQQG +EwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UECwwXQ2VydGlmaWNh +dGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNpYWwgQ2VydGlmaWNh +dGlvbiBBdXRob3JpdHkgRTEwHhcNMTkwNDI4MTYyMDQ0WhcNNDkwNDIwMTYyMDQ0 +WjB6MQswCQYDVQQGEwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UE +CwwXQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNp +YWwgQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkgRTEwdjAQBgcqhkjOPQIBBgUrgQQA +IgNiAASCCRa94QI0vR5Up9Yr9HEupz6hSoyjySYqo7v837KnmjveUIUNiuC9pWAU +WP3jwLX3HkzeiNdeg22a0IZPoSUCpasufiLAnfXh6NInLiWBrjLJXDSGaY7vaokt +rpZvAdmjXTBbMAsGA1UdDwQEAwIBBjAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBRZ +4ZTgDpksHL2qcpkFkxD2zVd16TAfBgNVHSMEGDAWgBRZ4ZTgDpksHL2qcpkFkxD2 +zVd16TAMBggqhkjOPQQDAwUAA2gAMGUCMQD4IoqT2hTUn0jt7oXLdMJ8q4vLp6sg +wHfPiOr9gxreb+e6Oidwd2LDnC4OUqCWiF8CMAzwKs4SnDJYcMLf2vpkbuVE4dTH +Rglz+HGcTLWsFs4KxLsq7MuU+vJTBUeDJeDjdA== +-----END CERTIFICATE----- + +-----BEGIN CERTIFICATE----- +MIIDxTCCAq2gAwIBAgIUEMdk6dVgOEIS2cCP0Q43P90Ps5YwDQYJKoZIhvcNAQEF +BQAwajELMAkGA1UEBhMCQ04xEzARBgNVBAoMCmlUcnVzQ2hpbmExHDAaBgNVBAsM +E0NoaW5hIFRydXN0IE5ldHdvcmsxKDAmBgNVBAMMH2lUcnVzQ2hpbmEgQ2xhc3Mg +MiBSb290IENBIC0gRzMwHhcNMTMwNDE4MDkzNjU2WhcNMzMwNDE4MDkzNjU2WjBq +MQswCQYDVQQGEwJDTjETMBEGA1UECgwKaVRydXNDaGluYTEcMBoGA1UECwwTQ2hp +bmEgVHJ1c3QgTmV0d29yazEoMCYGA1UEAwwfaVRydXNDaGluYSBDbGFzcyAyIFJv +b3QgQ0EgLSBHMzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOPPShpV +nJbMqqCw6Bz1kehnoPst9pkr0V9idOwU2oyS47/HjJXk9Rd5a9xfwkPO88trUpz5 +4GmmwspDXjVFu9L0eFaRuH3KMha1Ak01citbF7cQLJlS7XI+tpkTGHEY5pt3EsQg +wykfZl/A1jrnSkspMS997r2Gim54cwz+mTMgDRhZsKK/lbOeBPpWtcFizjXYCqhw +WktvQfZBYi6o4sHCshnOswi4yV1p+LuFcQ2ciYdWvULh1eZhLxHbGXyznYHi0dGN +z+I9H8aXxqAQfHVhbdHNzi77hCxFjOy+hHrGsyzjrd2swVQ2iUWP8BfEQqGLqM1g +KgWKYfcTGdbPB1MCAwEAAaNjMGEwHQYDVR0OBBYEFG/oAMxTVe7y0+408CTAK8hA +uTyRMB8GA1UdIwQYMBaAFG/oAMxTVe7y0+408CTAK8hAuTyRMA8GA1UdEwEB/wQF +MAMBAf8wDgYDVR0PAQH/BAQDAgEGMA0GCSqGSIb3DQEBBQUAA4IBAQBLnUTfW7hp +emMbuUGCk7RBswzOT83bDM6824EkUnf+X0iKS95SUNGeeSWK2o/3ALJo5hi7GZr3 +U8eLaWAcYizfO99UXMRBPw5PRR+gXGEronGUugLpxsjuynoLQu8GQAeysSXKbN1I +UugDo9u8igJORYA+5ms0s5sCUySqbQ2R5z/GoceyI9LdxIVa1RjVX8pYOj8JFwtn +DJN3ftSFvNMYwRuILKuqUYSHc2GPYiHVflDh5nDymCMOQFcFG3WsEuB+EYQPFgIU +1DHmdZcz7Llx8UOZXX2JupWCYzK1XhJb+r4hK5ncf/w8qGtYlmyJpxk3hr1TfUJX +Yf4Zr0fJsGuv +-----END CERTIFICATE----- \ No newline at end of file diff --git a/res/certs/alipay/appPublicCert.crt b/res/certs/alipay/appPublicCert.crt new file mode 100644 index 0000000..ddc9029 --- /dev/null +++ b/res/certs/alipay/appPublicCert.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDmTCCAoGgAwIBAgIQICMRB2LW76yahgdg3IFNPDANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE +BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv +biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1 +dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDU0NjE5WhcNMjQxMTExMDU0NjE5WjBr +MQswCQYDVQQGEwJDTjEfMB0GA1UECgwWbWJvbmZ5OTAxNUBzYW5kYm94LmNvbTEPMA0GA1UECwwG +QWxpcGF5MSowKAYDVQQDDCEyMDg4NzIxMDIwNzUwNTgxLTkwMjEwMDAxMzE2NTgwMjMwggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCxihQPf1Q+g9ArgM46shVqL5sbRha/df95D1PsWyEq +ANmWmG4zZ+ksYDVQrc4KzhSRoi56sm/7TDFYTmM6bW99e/nKW58WxyZB4ie5qA3F4n17psPyDqb8 +IokcQmCphSFDaXQD6AoXoLNtTM0vAI2cWxAgebZ/vsrdj5Ntjt+Rp3NYMCk1i5xovHcfILzLEGbX +QXoT9fo5AhHotTWa6xHVLPUGY9qwLzQxHzBmvy5ZMfnOfJkm/mDisTSqAUB59F3dzU/1ARVkEZ1w +Mgb4XohWBw6iurQfbMnH2mIomAAwwZVFv+sXDbL9yMbSMo/SjVsTQprn0Q0EnwLo7nmmOM6HAgMB +AAGjEjAQMA4GA1UdDwEB/wQEAwIE8DANBgkqhkiG9w0BAQsFAAOCAQEAn3Y4/C1h9R6ONsBqX3/q +XfHX7yX1FM0Y1x48X3/Yxk6HivAkTukhhhVYVKJsbrbzRqHDp9vhAP/FR6o6pAevaYMmLov0VMXU +7oAuetgkaYEYkDuNen5/Hpdhqi2vTtdT+q9w8zHJd6MDQ0aoHgIxpLKw5vof2R1N4fwSgNXMiXE5 +kmllKQMem/+on2p+Sj80/2asxryHIGlH87qPzkffv+kIOkZthbTApTFLLjdVri2QHGe8/cc4xy01 +/9iR3IUzNahotT41lJ4bMevBY7XMAS3n5ekyABN/9ZRJqhWdXgmFCRN/u56qd6lDgu7R2M2QUoyc +LuW5DfgRItKlmUB7sw== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/res/certs/alipay/privateKey.txt b/res/certs/alipay/privateKey.txt new file mode 100644 index 0000000..562dae4 --- /dev/null +++ b/res/certs/alipay/privateKey.txt @@ -0,0 +1 @@ +MIIEpQIBAAKCAQEAsYoUD39UPoPQK4DOOrIVai+bG0YWv3X/eQ9T7FshKgDZlphuM2fpLGA1UK3OCs4UkaIuerJv+0wxWE5jOm1vfXv5ylufFscmQeInuagNxeJ9e6bD8g6m/CKJHEJgqYUhQ2l0A+gKF6CzbUzNLwCNnFsQIHm2f77K3Y+TbY7fkadzWDApNYucaLx3HyC8yxBm10F6E/X6OQIR6LU1musR1Sz1BmPasC80MR8wZr8uWTH5znyZJv5g4rE0qgFAefRd3c1P9QEVZBGdcDIG+F6IVgcOorq0H2zJx9piKJgAMMGVRb/rFw2y/cjG0jKP0o1bE0Ka59ENBJ8C6O55pjjOhwIDAQABAoIBAFetNfz1R7hbxjlFshMAkVzQR8wvT9qbvl+dtzdZRcaFhu89NecDIP7+QDYor0FcxoGpU0TazDyRQyk2BQD8vHt+9zv9BVLtZLJSqoWgPbUFBi1DjS8EF2ka8RVYnn35NhUhhd7L//ftL88Bh673mfembQ9srDjoEy1Z01feoABAnCMkNFl986DmEwnarvEufXSDIgeN4ioMxha4NvfIPuI0zpVdV1O9sv+SGC+VEWZBtN3GNsaf4zS/f8FVGvTiU/Abz0gSw/iwSPHclDWQDTN3yFHf/tfqlzh0mH0WfhnuOBFWXzK+R7fbnM+asI9ttvzRcfpzgRGXdPcNcOv/6cECgYEA3DVqpi1k8MYfJixju6SG5gfyhM4VFksFmCMaNPgtatDMBKLMTgV/Ej6LXREojcy29uZl83F09pVlpd41eG39ULIPktixA/BqErQ2UaWh6kOxifycpu22Jh0r09hax6UgVrcBrrnCJEjcFsuJlrZvXQSzc3PBxjWy5gjabS5h9iECgYEAzmVAIh2frF01Y95zsLueAhhZwCtPanm6kf7ivR4r1plIX3b2sNRhWGmEHFgaCE6Braa0ogQ73Hd26kw4ZW+D6QMGC/zjCBEzDLLf++SjdVUHiY5AR4WHqXzq1jdAlsVyo9R661oAOp3lhiJVGLNXkHyEfEVPHsaxJh4osYSbX6cCgYEAx32Qx0i6eDFTyLZQB46uMrgiaVN04QRH5iJuvGvUYT8UhGKjaU8rZfDJOh+wOH2rhxMEaz1uc3C2bERY9mfWI4Ob/jFWc7YZsiYWS3Mcsuhubw4tMECLUg39RWZsHw8ls8kIuixIh6yFzhTH6YQOcRswIrhMZG8DScfdcSmiz2ECgYEAkWP1t5KSpkLKl11etcKUXfl1T8+yk9jIOowIgRw92WAFAWq2AH67TCKYM7dEL1HOO9tRJ0hAOt/U3ttuZtYVYBEHM26jJ02mXm2rJrA7DS4mrxmL4lYH6LbcXqZxU0Qnq4zEQgIWYzRTORf6Rfof1uJAGaJhR9bDd4yLMfGt2cUCgYEAo216Y61xOHUTA4AF1eekk+r+uOcQgQDvLXfs9FkDdJLk0mPG48/+eIYpPFnANJ/riF/DWOp8WGEe2IzA9yUFexzDbNQK8ha9kGcxaSAyiCwzjZ/t9/+hScDSV8kNqWSRSisu/YOFleEHbokT6mbLZ+gdqES8mUUanaEBzRQYGxo= \ No newline at end of file diff --git a/res/img/alipay.jpg b/res/img/alipay.jpg new file mode 100644 index 0000000..af7b406 Binary files /dev/null and b/res/img/alipay.jpg differ diff --git a/res/img/wechat-pay.jpg b/res/img/wechat-pay.jpg new file mode 100644 index 0000000..db39839 Binary files /dev/null and b/res/img/wechat-pay.jpg differ diff --git a/res/ip2region.xdb b/res/ip2region.xdb new file mode 100644 index 0000000..c78b792 Binary files /dev/null and b/res/ip2region.xdb differ diff --git a/service/captcha_service.go b/service/captcha_service.go new file mode 100644 index 0000000..864e939 --- /dev/null +++ b/service/captcha_service.go @@ -0,0 +1,110 @@ +package service + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "errors" + "fmt" + "geekai/core/types" + "github.com/imroc/req/v3" + "time" +) + +type CaptchaService struct { + config types.ApiConfig + client *req.Client +} + +func NewCaptchaService(config types.ApiConfig) *CaptchaService { + return &CaptchaService{ + config: config, + client: req.C().SetTimeout(10 * time.Second), + } +} + +func (s *CaptchaService) Get() (interface{}, error) { + if s.config.Token == "" { + return nil, errors.New("无效的 API Token") + } + + url := fmt.Sprintf("%s/api/captcha/get", s.config.ApiURL) + var res types.BizVo + r, err := s.client.R(). + SetHeader("AppId", s.config.AppId). + SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)). + SetSuccessResult(&res).Get(url) + if err != nil || r.IsErrorState() { + return nil, fmt.Errorf("请求 API 失败:%v", err) + } + + if res.Code != types.Success { + return nil, fmt.Errorf("请求 API 失败:%s", res.Message) + } + + return res.Data, nil +} + +func (s *CaptchaService) Check(data interface{}) bool { + url := fmt.Sprintf("%s/api/captcha/check", s.config.ApiURL) + var res types.BizVo + r, err := s.client.R(). + SetHeader("AppId", s.config.AppId). + SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)). + SetBodyJsonMarshal(data). + SetSuccessResult(&res).Post(url) + if err != nil || r.IsErrorState() { + return false + } + + if res.Code != types.Success { + return false + } + + return true +} + +func (s *CaptchaService) SlideGet() (interface{}, error) { + if s.config.Token == "" { + return nil, errors.New("无效的 API Token") + } + + url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL) + var res types.BizVo + r, err := s.client.R(). + SetHeader("AppId", s.config.AppId). + SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)). + SetSuccessResult(&res).Get(url) + if err != nil || r.IsErrorState() { + return nil, fmt.Errorf("请求 API 失败:%v", err) + } + + if res.Code != types.Success { + return nil, fmt.Errorf("请求 API 失败:%s", res.Message) + } + + return res.Data, nil +} + +func (s *CaptchaService) SlideCheck(data interface{}) bool { + url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL) + var res types.BizVo + r, err := s.client.R(). + SetHeader("AppId", s.config.AppId). + SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)). + SetBodyJsonMarshal(data). + SetSuccessResult(&res).Post(url) + if err != nil || r.IsErrorState() { + return false + } + + if res.Code != types.Success { + return false + } + + return true +} diff --git a/service/dalle/service.go b/service/dalle/service.go new file mode 100644 index 0000000..f3e813b --- /dev/null +++ b/service/dalle/service.go @@ -0,0 +1,313 @@ +package dalle + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "errors" + "fmt" + "geekai/core/types" + logger2 "geekai/logger" + "geekai/service" + "geekai/service/oss" + "geekai/service/sd" + "geekai/store" + "geekai/store/model" + "geekai/utils" + "github.com/go-redis/redis/v8" + "time" + + "github.com/imroc/req/v3" + "gorm.io/gorm" +) + +var logger = logger2.GetLogger() + +// DALL-E 绘画服务 + +type Service struct { + httpClient *req.Client + db *gorm.DB + uploadManager *oss.UploaderManager + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + Clients *types.LMap[uint, *types.WsClient] // UserId => Client +} + +func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service { + return &Service{ + httpClient: req.C().SetTimeout(time.Minute * 3), + db: db, + taskQueue: store.NewRedisQueue("DallE_Task_Queue", redisCli), + notifyQueue: store.NewRedisQueue("DallE_Notify_Queue", redisCli), + Clients: types.NewLMap[uint, *types.WsClient](), + uploadManager: manager, + } +} + +// PushTask push a new mj task in to task queue +func (s *Service) PushTask(task types.DallTask) { + logger.Infof("add a new DALL-E task to the task list: %+v", task) + s.taskQueue.RPush(task) +} + +func (s *Service) Run() { + logger.Info("Starting DALL-E job consumer...") + go func() { + for { + var task types.DallTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + logger.Infof("handle a new DALL-E task: %+v", task) + _, err = s.Image(task, false) + if err != nil { + logger.Errorf("error with image task: %v", err) + s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ + "progress": -1, + "err_msg": err.Error(), + }) + s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed}) + } + } + }() +} + +type imgReq struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N int `json:"n"` + Size string `json:"size"` + Quality string `json:"quality"` + Style string `json:"style"` +} + +type imgRes struct { + Created int64 `json:"created"` + Data []struct { + RevisedPrompt string `json:"revised_prompt"` + Url string `json:"url"` + } `json:"data"` +} + +type ErrRes struct { + Error struct { + Code interface{} `json:"code"` + Message string `json:"message"` + Param interface{} `json:"param"` + Type string `json:"type"` + } `json:"error"` +} + +func (s *Service) Image(task types.DallTask, sync bool) (string, error) { + logger.Debugf("绘画参数:%+v", task) + prompt := task.Prompt + // translate prompt + if utils.HasChinese(task.Prompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt)) + if err != nil { + return "", fmt.Errorf("error with translate prompt: %v", err) + } + prompt = content + logger.Debugf("重写后提示词:%s", prompt) + } + + var user model.User + s.db.Where("id", task.UserId).First(&user) + if user.Power < task.Power { + return "", errors.New("insufficient of power") + } + + // get image generation API KEY + var apiKey model.ApiKey + tx := s.db.Where("platform", types.OpenAI.Value). + Where("type", "img"). + Where("enabled", true). + Order("last_used_at ASC").First(&apiKey) + if tx.Error != nil { + return "", fmt.Errorf("no available IMG api key: %v", tx.Error) + } + + var res imgRes + var errRes ErrRes + if len(apiKey.ProxyURL) > 5 { + s.httpClient.SetProxyURL(apiKey.ProxyURL).R() + } + logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) + r, err := s.httpClient.R().SetHeader("Content-Type", "application/json"). + SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(imgReq{ + Model: "dall-e-3", + Prompt: prompt, + N: 1, + Size: task.Size, + Style: task.Style, + Quality: task.Quality, + }). + SetErrorResult(&errRes). + SetSuccessResult(&res).Post(apiKey.ApiURL) + if err != nil { + return "", fmt.Errorf("error with send request: %v", err) + } + + if r.IsErrorState() { + return "", fmt.Errorf("error with send request: %v", errRes.Error) + } + // update the api key last use time + s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) + // update task progress + tx = s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{ + "progress": 100, + "org_url": res.Data[0].Url, + "prompt": prompt, + }) + if tx.Error != nil { + return "", fmt.Errorf("err with update database: %v", tx.Error) + } + + s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished}) + var content string + if sync { + imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url) + if err != nil { + return "", fmt.Errorf("error with download image: %v", err) + } + content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n![](%s)\n", prompt, imgURL) + } + + // 更新用户算力 + tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power)) + // 记录算力变化日志 + if tx.Error == nil && tx.RowsAffected > 0 { + var u model.User + s.db.Where("id", user.Id).First(&u) + s.db.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: task.Power, + Balance: u.Power, + Mark: types.PowerSub, + Model: "dall-e-3", + Remark: fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)), + CreatedAt: time.Now(), + }) + } + + return content, nil +} + +func (s *Service) CheckTaskNotify() { + go func() { + logger.Info("Running DALL-E task notify checking ...") + for { + var message sd.NotifyMessage + err := s.notifyQueue.LPop(&message) + if err != nil { + continue + } + client := s.Clients.Get(uint(message.UserId)) + if client == nil { + continue + } + err = client.Send([]byte(message.Message)) + if err != nil { + continue + } + } + }() +} + +func (s *Service) DownloadImages() { + go func() { + var items []model.DallJob + for { + res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items) + if res.Error != nil { + continue + } + + // download images + for _, v := range items { + if v.OrgURL == "" { + continue + } + + logger.Infof("try to download image: %s", v.OrgURL) + imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL) + if err != nil { + logger.Error("error with download image: %s, error: %v", imgURL, err) + continue + } else { + logger.Infof("download image %s successfully.", v.OrgURL) + } + + } + + time.Sleep(time.Second * 5) + } + }() +} + +func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) { + // sava image + imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false) + if err != nil { + return "", err + } + + // update img_url + res := s.db.Model(&model.DallJob{Id: jobId}).UpdateColumn("img_url", imgURL) + if res.Error != nil { + return "", err + } + s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished}) + return imgURL, nil +} + +// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务 +func (s *Service) CheckTaskStatus() { + go func() { + logger.Info("Running Stable-Diffusion task status checking ...") + for { + var jobs []model.DallJob + res := s.db.Where("progress < ?", 100).Find(&jobs) + if res.Error != nil { + time.Sleep(5 * time.Second) + continue + } + + for _, job := range jobs { + // 5 分钟还没完成的任务直接删除 + if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 { + s.db.Delete(&job) + var user model.User + s.db.Where("id = ?", job.UserId).First(&user) + // 退回绘图次数 + res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) + if res.Error == nil && res.RowsAffected > 0 { + s.db.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: job.Power, + Balance: user.Power + job.Power, + Mark: types.PowerAdd, + Model: "dall-e-3", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%d", job.Id), + CreatedAt: time.Now(), + }) + } + continue + } + } + time.Sleep(time.Second * 10) + } + }() +} diff --git a/service/license_service.go b/service/license_service.go new file mode 100644 index 0000000..419c02d --- /dev/null +++ b/service/license_service.go @@ -0,0 +1,197 @@ +package service + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core" + "geekai/core/types" + "geekai/store" + "time" + + "github.com/imroc/req/v3" +) + +type LicenseService struct { + config types.ApiConfig + levelDB *store.LevelDB + license *types.License + urlWhiteList []string + machineId string +} + +func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService { + var license types.License + return &LicenseService{ + config: server.Config.ApiConfig, + levelDB: levelDB, + license: &license, + machineId: "", + } +} + +type License struct { + Name string `json:"name"` + License string `json:"license"` + MachineId string `json:"mid"` + ActiveAt int64 `json:"active_at"` + ExpiredAt int64 `json:"expired_at"` + UserNum int `json:"user_num"` + Configs types.LicenseConfig `json:"configs"` +} + +// ActiveLicense 激活 License +func (s *LicenseService) ActiveLicense(license string, machineId string) error { + var res struct { + Code types.BizCode `json:"code"` + Message string `json:"message"` + Data License `json:"data"` + } + apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active") + response, err := req.C().R(). + SetBody(map[string]string{"license": license, "machine_id": machineId}). + SetSuccessResult(&res).Post(apiURL) + if err != nil { + return fmt.Errorf("发送激活请求失败: %v", err) + } + + if response.IsErrorState() { + return fmt.Errorf("发送激活请求失败:%v", response.Status) + } + + if res.Code != types.Success { + return fmt.Errorf("激活失败:%v", res.Message) + } + + s.license = &types.License{ + Key: license, + MachineId: machineId, + Configs: res.Data.Configs, + ExpiredAt: res.Data.ExpiredAt, + IsActive: true, + } + err = s.levelDB.Put(types.LicenseKey, s.license) + if err != nil { + return fmt.Errorf("保存许可证书失败:%v", err) + } + return nil +} + +// SyncLicense 定期同步 License +func (s *LicenseService) SyncLicense() { + go func() { + retryCounter := 0 + for { + license, err := s.fetchLicense() + if err != nil { + retryCounter++ + if retryCounter < 5 { + logger.Error(err) + } + s.license.IsActive = false + } else { + s.license = license + } + + urls, err := s.fetchUrlWhiteList() + if err == nil { + s.urlWhiteList = urls + } + + time.Sleep(time.Second * 10) + } + }() +} + +func (s *LicenseService) fetchLicense() (*types.License, error) { + //var res struct { + // Code types.BizCode `json:"code"` + // Message string `json:"message"` + // Data License `json:"data"` + //} + //apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check") + //response, err := req.C().R(). + // SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}). + // SetSuccessResult(&res).Post(apiURL) + //if err != nil { + // return nil, fmt.Errorf("发送激活请求失败: %v", err) + //} + //if response.IsErrorState() { + // return nil, fmt.Errorf("激活失败:%v", response.Status) + //} + //if res.Code != types.Success { + // return nil, fmt.Errorf("激活失败:%v", res.Message) + //} + + return &types.License{ + Key: "abc", + MachineId: "abc", + Configs: types.LicenseConfig{ + UserNum: 10000, + DeCopy: false, + }, + ExpiredAt: 0, + IsActive: true, + }, nil +} + +func (s *LicenseService) fetchUrlWhiteList() ([]string, error) { + var res struct { + Code types.BizCode `json:"code"` + Message string `json:"message"` + Data []string `json:"data"` + } + apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls") + response, err := req.C().R().SetSuccessResult(&res).Get(apiURL) + if err != nil { + return nil, fmt.Errorf("发送请求失败: %v", err) + } + if response.IsErrorState() { + return nil, fmt.Errorf("发送请求失败:%v", response.Status) + } + if res.Code != types.Success { + return nil, fmt.Errorf("获取白名单失败:%v", res.Message) + } + + return res.Data, nil +} + +// GetLicense 获取许可信息 +func (s *LicenseService) GetLicense() *types.License { + return s.license +} + +// IsValidApiURL 判断是否合法的中转 URL +func (s *LicenseService) IsValidApiURL(uri string) error { + // 获得许可授权的直接放行 + return nil + //if s.license.IsActive { + // if s.license.MachineId != s.machineId { + // return errors.New("系统使用了盗版的许可证书") + // } + // + // if time.Now().Unix() > s.license.ExpiredAt { + // return errors.New("系统许可证书已经过期") + // } + // return nil + //} + // + //if len(s.urlWhiteList) == 0 { + // urls, err := s.fetchUrlWhiteList() + // if err == nil { + // s.urlWhiteList = urls + // } + //} + // + //for _, v := range s.urlWhiteList { + // if strings.HasPrefix(uri, v) { + // return nil + // } + //} + //return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri) +} diff --git a/service/mj/client.go b/service/mj/client.go new file mode 100644 index 0000000..504553f --- /dev/null +++ b/service/mj/client.go @@ -0,0 +1,68 @@ +package mj + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import "geekai/core/types" + +type Client interface { + Imagine(task types.MjTask) (ImageRes, error) + Blend(task types.MjTask) (ImageRes, error) + SwapFace(task types.MjTask) (ImageRes, error) + Upscale(task types.MjTask) (ImageRes, error) + Variation(task types.MjTask) (ImageRes, error) + QueryTask(taskId string) (QueryRes, error) +} + +type ImageReq struct { + BotType string `json:"botType,omitempty"` + Prompt string `json:"prompt,omitempty"` + Dimensions string `json:"dimensions,omitempty"` + Base64Array []string `json:"base64Array,omitempty"` + AccountFilter interface{} `json:"accountFilter,omitempty"` + NotifyHook string `json:"notifyHook,omitempty"` + State string `json:"state,omitempty"` +} + +type ImageRes struct { + Code int `json:"code"` + Description string `json:"description"` + Properties struct { + } `json:"properties"` + Result string `json:"result"` +} + +type ErrRes struct { + Error struct { + Message string `json:"message"` + } `json:"error"` +} + +type QueryRes struct { + Action string `json:"action"` + Buttons []struct { + CustomId string `json:"customId"` + Emoji string `json:"emoji"` + Label string `json:"label"` + Style int `json:"style"` + Type int `json:"type"` + } `json:"buttons"` + Description string `json:"description"` + FailReason string `json:"failReason"` + FinishTime int `json:"finishTime"` + Id string `json:"id"` + ImageUrl string `json:"imageUrl"` + Progress string `json:"progress"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Properties struct { + } `json:"properties"` + StartTime int `json:"startTime"` + State string `json:"state"` + Status string `json:"status"` + SubmitTime int `json:"submitTime"` +} diff --git a/service/mj/plus_client.go b/service/mj/plus_client.go new file mode 100644 index 0000000..beb8943 --- /dev/null +++ b/service/mj/plus_client.go @@ -0,0 +1,267 @@ +package mj + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "encoding/base64" + "errors" + "fmt" + "geekai/core/types" + "geekai/service" + "geekai/utils" + "github.com/imroc/req/v3" + "io" + "time" + + "github.com/gin-gonic/gin" +) + +// PlusClient MidJourney Plus ProxyClient +type PlusClient struct { + Config types.MjPlusConfig + apiURL string + client *req.Client + licenseService *service.LicenseService +} + +func NewPlusClient(config types.MjPlusConfig, licenseService *service.LicenseService) *PlusClient { + return &PlusClient{ + Config: config, + apiURL: config.ApiURL, + client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"), + licenseService: licenseService, + } +} + +func (c *PlusClient) preCheck() error { + return c.licenseService.IsValidApiURL(c.Config.ApiURL) +} + +func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) { + if err := c.preCheck(); err != nil { + return ImageRes{}, err + } + + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode) + prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params) + if task.NegPrompt != "" { + prompt += fmt.Sprintf(" --no %s", task.NegPrompt) + } + body := ImageReq{ + BotType: "MID_JOURNEY", + Prompt: prompt, + Base64Array: make([]string, 0), + } + // 生成图片 Base64 编码 + if len(task.ImgArr) > 0 { + imageData, err := utils.DownloadImage(task.ImgArr[0], "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) + } + + } + logger.Info("API URL: ", apiURL) + var res ImageRes + var errRes ErrRes + r, err := c.client.R(). + SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) + } + + if r.IsErrorState() { + errStr, _ := io.ReadAll(r.Body) + return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr)) + } + + return res, nil +} + +// Blend 融图 +func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) { + if err := c.preCheck(); err != nil { + return ImageRes{}, err + } + + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) + body := ImageReq{ + BotType: "MID_JOURNEY", + Dimensions: "SQUARE", + Base64Array: make([]string, 0), + } + // 生成图片 Base64 编码 + if len(task.ImgArr) > 0 { + for _, imgURL := range task.ImgArr { + imageData, err := utils.DownloadImage(imgURL, "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) + } + } + } + var res ImageRes + var errRes ErrRes + r, err := c.client.R(). + SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +// SwapFace 换脸 +func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) { + if err := c.preCheck(); err != nil { + return ImageRes{}, err + } + + apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode) + // 生成图片 Base64 编码 + if len(task.ImgArr) != 2 { + return ImageRes{}, errors.New("参数错误,必须上传2张图片") + } + var sourceBase64 string + var targetBase64 string + imageData, err := utils.DownloadImage(task.ImgArr[0], "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) + } + imageData, err = utils.DownloadImage(task.ImgArr[1], "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData) + } + + body := gin.H{ + "sourceBase64": sourceBase64, + "targetBase64": targetBase64, + "accountFilter": gin.H{ + "instanceId": "", + }, + "state": "", + } + var res ImageRes + var errRes ErrRes + r, err := c.client.SetTimeout(time.Minute).R(). + SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +// Upscale 放大指定的图片 +func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) { + if err := c.preCheck(); err != nil { + return ImageRes{}, err + } + + body := map[string]string{ + "customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash), + "taskId": task.MessageId, + } + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) + var res ImageRes + var errRes ErrRes + r, err := c.client.R(). + SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 +func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) { + if err := c.preCheck(); err != nil { + return ImageRes{}, err + } + + body := map[string]string{ + "customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash), + "taskId": task.MessageId, + } + apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode) + logger.Info("API URL: ", apiURL) + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) { + apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId) + var res QueryRes + r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey). + SetSuccessResult(&res). + Get(apiURL) + + if err != nil { + return QueryRes{}, err + } + + if r.IsErrorState() { + return QueryRes{}, errors.New("error status:" + r.Status) + } + + return res, nil +} + +var _ Client = &PlusClient{} diff --git a/service/mj/pool.go b/service/mj/pool.go new file mode 100644 index 0000000..ddddd28 --- /dev/null +++ b/service/mj/pool.go @@ -0,0 +1,230 @@ +package mj + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core/types" + logger2 "geekai/logger" + "geekai/service" + "geekai/service/oss" + "geekai/service/sd" + "geekai/store" + "geekai/store/model" + "github.com/go-redis/redis/v8" + "strings" + "time" + + "gorm.io/gorm" +) + +// ServicePool Mj service pool +type ServicePool struct { + services []*Service + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + db *gorm.DB + uploaderManager *oss.UploaderManager + Clients *types.LMap[uint, *types.WsClient] // UserId => Client + licenseService *service.LicenseService +} + +var logger = logger2.GetLogger() + +func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService) *ServicePool { + services := make([]*Service, 0) + taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli) + notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli) + return &ServicePool{ + taskQueue: taskQueue, + notifyQueue: notifyQueue, + services: services, + uploaderManager: manager, + db: db, + Clients: types.NewLMap[uint, *types.WsClient](), + licenseService: licenseService, + } +} + +func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) { + // stop old service + for _, s := range p.services { + s.Stop() + } + p.services = make([]*Service, 0) + + for k, config := range plusConfigs { + if config.Enabled == false { + continue + } + + cli := NewPlusClient(config, p.licenseService) + name := fmt.Sprintf("mj-plus-service-%d", k) + plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli) + go func() { + plusService.Run() + }() + p.services = append(p.services, plusService) + } + + // for mid-journey proxy + for k, config := range proxyConfigs { + if config.Enabled == false { + continue + } + cli := NewProxyClient(config) + name := fmt.Sprintf("mj-proxy-service-%d", k) + proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli) + go func() { + proxyService.Run() + }() + p.services = append(p.services, proxyService) + } +} + +func (p *ServicePool) CheckTaskNotify() { + go func() { + for { + var message sd.NotifyMessage + err := p.notifyQueue.LPop(&message) + if err != nil { + continue + } + cli := p.Clients.Get(uint(message.UserId)) + if cli == nil { + continue + } + err = cli.Send([]byte(message.Message)) + if err != nil { + continue + } + } + }() +} + +func (p *ServicePool) DownloadImages() { + go func() { + var items []model.MidJourneyJob + for { + res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items) + if res.Error != nil { + continue + } + + // download images + for _, v := range items { + if v.OrgURL == "" { + continue + } + + logger.Infof("try to download image: %s", v.OrgURL) + mjService := p.getService(v.ChannelId) + if mjService == nil { + logger.Errorf("Invalid task: %+v", v) + continue + } + + task, _ := mjService.Client.QueryTask(v.TaskId) + if len(task.Buttons) > 0 { + v.Hash = GetImageHash(task.Buttons[0].CustomId) + } + // 如果是返回的是 discord 图片地址,则使用代理下载 + proxy := false + if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") { + proxy = true + } + imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, proxy) + + if err != nil { + logger.Errorf("error with download image %s, %v", v.OrgURL, err) + continue + } else { + logger.Infof("download image %s successfully.", v.OrgURL) + } + + v.ImgURL = imgURL + p.db.Updates(&v) + + cli := p.Clients.Get(uint(v.UserId)) + if cli == nil { + continue + } + err = cli.Send([]byte(sd.Finished)) + if err != nil { + continue + } + } + + time.Sleep(time.Second * 5) + } + }() +} + +// PushTask push a new mj task in to task queue +func (p *ServicePool) PushTask(task types.MjTask) { + logger.Debugf("add a new MidJourney task to the task list: %+v", task) + p.taskQueue.RPush(task) +} + +// HasAvailableService check if it has available mj service in pool +func (p *ServicePool) HasAvailableService() bool { + return len(p.services) > 0 +} + +// SyncTaskProgress 异步拉取任务 +func (p *ServicePool) SyncTaskProgress() { + go func() { + var items []model.MidJourneyJob + for { + res := p.db.Where("progress < ?", 100).Find(&items) + if res.Error != nil { + continue + } + + for _, job := range items { + // 失败或者 30 分钟还没完成的任务删除并退回算力 + if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 { + p.db.Delete(&job) + // 退回算力 + tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) + if tx.Error == nil && tx.RowsAffected > 0 { + var user model.User + p.db.Where("id = ?", job.UserId).First(&user) + p.db.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: job.Power, + Balance: user.Power + job.Power, + Mark: types.PowerAdd, + Model: "mid-journey", + Remark: fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s", job.TaskId), + CreatedAt: time.Now(), + }) + } + continue + } + + if servicePlus := p.getService(job.ChannelId); servicePlus != nil { + _ = servicePlus.Notify(job) + } + } + + time.Sleep(time.Second * 10) + } + }() +} + +func (p *ServicePool) getService(name string) *Service { + for _, s := range p.services { + if s.Name == name { + return s + } + } + return nil +} diff --git a/service/mj/proxy_client.go b/service/mj/proxy_client.go new file mode 100644 index 0000000..e6a557d --- /dev/null +++ b/service/mj/proxy_client.go @@ -0,0 +1,185 @@ +package mj + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "encoding/base64" + "errors" + "fmt" + "geekai/core/types" + "geekai/utils" + "github.com/imroc/req/v3" + "io" +) + +// ProxyClient MidJourney Proxy Client +type ProxyClient struct { + Config types.MjProxyConfig + apiURL string +} + +func NewProxyClient(config types.MjProxyConfig) *ProxyClient { + return &ProxyClient{Config: config, apiURL: config.ApiURL} +} + +func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) { + apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL) + prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params) + if task.NegPrompt != "" { + prompt += fmt.Sprintf(" --no %s", task.NegPrompt) + } + body := ImageReq{ + Prompt: prompt, + Base64Array: make([]string, 0), + } + // 生成图片 Base64 编码 + if len(task.ImgArr) > 0 { + imageData, err := utils.DownloadImage(task.ImgArr[0], "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) + } + + } + logger.Info("API URL: ", apiURL) + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("mj-api-secret", c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) + } + + if r.IsErrorState() { + errStr, _ := io.ReadAll(r.Body) + return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr)) + } + + return res, nil +} + +// Blend 融图 +func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) { + apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL) + body := ImageReq{ + Dimensions: "SQUARE", + Base64Array: make([]string, 0), + } + // 生成图片 Base64 编码 + if len(task.ImgArr) > 0 { + for _, imgURL := range task.ImgArr { + imageData, err := utils.DownloadImage(imgURL, "") + if err != nil { + logger.Error("error with download image: ", err) + } else { + body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData)) + } + } + } + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("mj-api-secret", c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +// SwapFace 换脸 +func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) { + return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能,请使用 MidJourney-Plus") +} + +// Upscale 放大指定的图片 +func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) { + body := map[string]interface{}{ + "action": "UPSCALE", + "index": task.Index, + "taskId": task.MessageId, + } + apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL) + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("mj-api-secret", c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效 +func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) { + body := map[string]interface{}{ + "action": "VARIATION", + "index": task.Index, + "taskId": task.MessageId, + } + apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL) + var res ImageRes + var errRes ErrRes + r, err := req.C().R(). + SetHeader("mj-api-secret", c.Config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + SetErrorResult(&errRes). + Post(apiURL) + if err != nil { + return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err) + } + + if r.IsErrorState() { + return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message) + } + + return res, nil +} + +func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) { + apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId) + var res QueryRes + r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey). + SetSuccessResult(&res). + Get(apiURL) + + if err != nil { + return QueryRes{}, err + } + + if r.IsErrorState() { + return QueryRes{}, errors.New("error status:" + r.Status) + } + + return res, nil +} + +var _ Client = &ProxyClient{} diff --git a/service/mj/service.go b/service/mj/service.go new file mode 100644 index 0000000..baccd28 --- /dev/null +++ b/service/mj/service.go @@ -0,0 +1,204 @@ +package mj + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core/types" + "geekai/service" + "geekai/service/sd" + "geekai/store" + "geekai/store/model" + "geekai/utils" + "strings" + "time" + + "gorm.io/gorm" +) + +// Service MJ 绘画服务 +type Service struct { + Name string // service Name + Client Client // MJ Client + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + db *gorm.DB + running bool +} + +func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service { + return &Service{ + Name: name, + db: db, + taskQueue: taskQueue, + notifyQueue: notifyQueue, + Client: cli, + running: true, + } +} + +func (s *Service) Run() { + logger.Infof("Starting MidJourney job consumer for %s", s.Name) + for s.running { + var task types.MjTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + + // 如果配置了多个中转平台的 API KEY + // U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表 + if task.ChannelId != "" && task.ChannelId != s.Name { + logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId) + s.taskQueue.RPush(task) + time.Sleep(time.Second) + continue + } + + // translate prompt + if utils.HasChinese(task.Prompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt)) + if err == nil { + task.Prompt = content + } else { + logger.Warnf("error with translate prompt: %v", err) + } + } + // translate negative prompt + if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt)) + if err == nil { + task.NegPrompt = content + } else { + logger.Warnf("error with translate prompt: %v", err) + } + } + + var job model.MidJourneyJob + tx := s.db.Where("id = ?", task.Id).First(&job) + if tx.Error != nil { + logger.Error("任务不存在,任务ID:", task.TaskId) + continue + } + + logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task) + var res ImageRes + switch task.Type { + case types.TaskImage: + res, err = s.Client.Imagine(task) + break + case types.TaskUpscale: + res, err = s.Client.Upscale(task) + break + case types.TaskVariation: + res, err = s.Client.Variation(task) + break + case types.TaskBlend: + res, err = s.Client.Blend(task) + break + case types.TaskSwapFace: + res, err = s.Client.SwapFace(task) + break + } + + if err != nil || (res.Code != 1 && res.Code != 22) { + var errMsg string + if err != nil { + errMsg = err.Error() + } else { + errMsg = fmt.Sprintf("%v,%s", err, res.Description) + } + + logger.Error("绘画任务执行失败:", errMsg) + job.Progress = -1 + job.ErrMsg = errMsg + // update the task progress + s.db.Updates(&job) + // 任务失败,通知前端 + s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed}) + continue + } + logger.Infof("任务提交成功:%+v", res) + // 更新任务 ID/频道 + job.TaskId = res.Result + job.MessageId = res.Result + job.ChannelId = s.Name + s.db.Updates(&job) + } +} + +func (s *Service) Stop() { + s.running = false +} + +type CBReq struct { + Id string `json:"id"` + Action string `json:"action"` + Status string `json:"status"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + Progress string `json:"progress"` + ImageUrl string `json:"imageUrl"` + FailReason interface{} `json:"failReason"` + Properties struct { + FinalPrompt string `json:"finalPrompt"` + } `json:"properties"` +} + +func (s *Service) Notify(job model.MidJourneyJob) error { + task, err := s.Client.QueryTask(job.TaskId) + if err != nil { + return err + } + + // 任务执行失败了 + if task.FailReason != "" { + s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{ + "progress": -1, + "err_msg": task.FailReason, + }) + s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed}) + return fmt.Errorf("task failed: %v", task.FailReason) + } + + if len(task.Buttons) > 0 { + job.Hash = GetImageHash(task.Buttons[0].CustomId) + } + oldProgress := job.Progress + job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0) + job.Prompt = task.PromptEn + if task.ImageUrl != "" { + job.OrgURL = task.ImageUrl + } + tx := s.db.Updates(&job) + if tx.Error != nil { + return fmt.Errorf("error with update database: %v", tx.Error) + } + // 通知前端更新任务进度 + if oldProgress != job.Progress { + message := sd.Running + if job.Progress == 100 { + message = sd.Finished + } + s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message}) + } + return nil +} + +func GetImageHash(action string) string { + split := strings.Split(action, "::") + if len(split) > 5 { + return split[4] + } + return split[len(split)-1] +} diff --git a/service/oss/aliyun_oss.go b/service/oss/aliyun_oss.go new file mode 100644 index 0000000..00dcc8d --- /dev/null +++ b/service/oss/aliyun_oss.go @@ -0,0 +1,137 @@ +package oss + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bytes" + "encoding/base64" + "fmt" + "geekai/core/types" + "geekai/utils" + "net/url" + "path/filepath" + "strings" + "time" + + "github.com/aliyun/aliyun-oss-go-sdk/oss" + "github.com/gin-gonic/gin" +) + +type AliYunOss struct { + config *types.AliYunOssConfig + bucket *oss.Bucket + proxyURL string +} + +func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) { + config := &appConfig.OSS.AliYun + // 创建 OSS 客户端 + client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret) + if err != nil { + return nil, err + } + + // 获取存储空间 + bucket, err := client.Bucket(config.Bucket) + if err != nil { + return nil, err + } + + if config.SubDir == "" { + config.SubDir = "gpt" + } + + return &AliYunOss{ + config: config, + bucket: bucket, + proxyURL: appConfig.ProxyURL, + }, nil + +} + +func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) { + // 解析表单 + file, err := ctx.FormFile(name) + if err != nil { + return File{}, err + } + // 打开上传文件 + src, err := file.Open() + if err != nil { + return File{}, err + } + defer src.Close() + + fileExt := filepath.Ext(file.Filename) + objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) + // 上传文件 + err = s.bucket.PutObject(objectKey, src) + if err != nil { + return File{}, err + } + + return File{ + Name: file.Filename, + ObjKey: objectKey, + URL: fmt.Sprintf("%s/%s", s.config.Domain, objectKey), + Ext: fileExt, + Size: file.Size, + }, nil +} + +func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) { + var imageData []byte + var err error + if useProxy { + imageData, err = utils.DownloadImage(imageURL, s.proxyURL) + } else { + imageData, err = utils.DownloadImage(imageURL, "") + } + if err != nil { + return "", fmt.Errorf("error with download image: %v", err) + } + parse, err := url.Parse(imageURL) + if err != nil { + return "", fmt.Errorf("error with parse image URL: %v", err) + } + fileExt := utils.GetImgExt(parse.Path) + objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) + // 上传文件字节数据 + err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData)) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil +} + +func (s AliYunOss) PutBase64(base64Img string) (string, error) { + imageData, err := base64.StdEncoding.DecodeString(base64Img) + if err != nil { + return "", fmt.Errorf("error decoding base64:%v", err) + } + objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro()) + // 上传文件字节数据 + err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData)) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil +} + +func (s AliYunOss) Delete(fileURL string) error { + var objectKey string + if strings.HasPrefix(fileURL, "http") { + filename := filepath.Base(fileURL) + objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename) + } else { + objectKey = fileURL + } + return s.bucket.DeleteObject(objectKey) +} + +var _ Uploader = AliYunOss{} diff --git a/service/oss/localstorage.go b/service/oss/localstorage.go new file mode 100644 index 0000000..f64ff05 --- /dev/null +++ b/service/oss/localstorage.go @@ -0,0 +1,105 @@ +package oss + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "encoding/base64" + "fmt" + "geekai/core/types" + "geekai/utils" + "github.com/gin-gonic/gin" + "net/url" + "os" + "path/filepath" + "strings" +) + +type LocalStorage struct { + config *types.LocalStorageConfig + proxyURL string +} + +func NewLocalStorage(config *types.AppConfig) LocalStorage { + return LocalStorage{ + config: &config.OSS.Local, + proxyURL: config.ProxyURL, + } +} + +func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) { + file, err := ctx.FormFile(name) + if err != nil { + return File{}, fmt.Errorf("error with get form: %v", err) + } + + path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false) + if err != nil { + return File{}, fmt.Errorf("error with generate filename: %s", err.Error()) + } + // 将文件保存到指定路径 + err = ctx.SaveUploadedFile(file, path) + if err != nil { + return File{}, fmt.Errorf("error with save upload file: %s", err.Error()) + } + + ext := filepath.Ext(file.Filename) + return File{ + Name: file.Filename, + ObjKey: path, + URL: utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, path), + Ext: ext, + Size: file.Size, + }, nil +} + +func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) { + parse, err := url.Parse(imageURL) + if err != nil { + return "", fmt.Errorf("error with parse image URL: %v", err) + } + filename := filepath.Base(parse.Path) + filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true) + if err != nil { + return "", fmt.Errorf("error with generate image dir: %v", err) + } + + if useProxy { + err = utils.DownloadFile(imageURL, filePath, s.proxyURL) + } else { + err = utils.DownloadFile(imageURL, filePath, "") + } + if err != nil { + return "", fmt.Errorf("error with download image: %v", err) + } + + return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil +} + +func (s LocalStorage) PutBase64(base64Img string) (string, error) { + imageData, err := base64.StdEncoding.DecodeString(base64Img) + if err != nil { + return "", fmt.Errorf("error decoding base64:%v", err) + } + filePath, err := utils.GenUploadPath(s.config.BasePath, "", true) + err = os.WriteFile(filePath, imageData, 0644) + if err != nil { + return "", fmt.Errorf("error writing to file:%v", err) + } + + return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil +} + +func (s LocalStorage) Delete(fileURL string) error { + if _, err := os.Stat(fileURL); err == nil { + return os.Remove(fileURL) + } + filePath := strings.Replace(fileURL, s.config.BaseURL, s.config.BasePath, 1) + return os.Remove(filePath) +} + +var _ Uploader = LocalStorage{} diff --git a/service/oss/minio_oss.go b/service/oss/minio_oss.go new file mode 100644 index 0000000..5eaca49 --- /dev/null +++ b/service/oss/minio_oss.go @@ -0,0 +1,137 @@ +package oss + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "context" + "encoding/base64" + "fmt" + "geekai/core/types" + "geekai/utils" + "net/url" + "path/filepath" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" +) + +type MiniOss struct { + config *types.MiniOssConfig + client *minio.Client + proxyURL string +} + +func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) { + config := &appConfig.OSS.Minio + minioClient, err := minio.New(config.Endpoint, &minio.Options{ + Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""), + Secure: config.UseSSL, + }) + if err != nil { + return MiniOss{}, err + } + if config.SubDir == "" { + config.SubDir = "gpt" + } + return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil +} + +func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) { + var imageData []byte + var err error + if useProxy { + imageData, err = utils.DownloadImage(imageURL, s.proxyURL) + } else { + imageData, err = utils.DownloadImage(imageURL, "") + } + if err != nil { + return "", fmt.Errorf("error with download image: %v", err) + } + parse, err := url.Parse(imageURL) + if err != nil { + return "", fmt.Errorf("error with parse image URL: %v", err) + } + fileExt := filepath.Ext(parse.Path) + filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) + info, err := s.client.PutObject( + context.Background(), + s.config.Bucket, + filename, + strings.NewReader(string(imageData)), + int64(len(imageData)), + minio.PutObjectOptions{ContentType: "image/png"}) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil +} + +func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) { + file, err := ctx.FormFile(name) + if err != nil { + return File{}, fmt.Errorf("error with get form: %v", err) + } + // Open the uploaded file + fileReader, err := file.Open() + if err != nil { + return File{}, fmt.Errorf("error opening file: %v", err) + } + defer fileReader.Close() + + fileExt := utils.GetImgExt(file.Filename) + filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) + info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{ + ContentType: file.Header.Get("Content-Type"), + }) + if err != nil { + return File{}, fmt.Errorf("error uploading to MinIO: %v", err) + } + + return File{ + Name: file.Filename, + ObjKey: info.Key, + URL: fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), + Ext: fileExt, + Size: file.Size, + }, nil +} + +func (s MiniOss) PutBase64(base64Img string) (string, error) { + imageData, err := base64.StdEncoding.DecodeString(base64Img) + if err != nil { + return "", fmt.Errorf("error decoding base64:%v", err) + } + objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro()) + info, err := s.client.PutObject( + context.Background(), + s.config.Bucket, + objectKey, + strings.NewReader(string(imageData)), + int64(len(imageData)), + minio.PutObjectOptions{ContentType: "image/png"}) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil +} + +func (s MiniOss) Delete(fileURL string) error { + var objectKey string + if strings.HasPrefix(fileURL, "http") { + filename := filepath.Base(fileURL) + objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename) + } else { + objectKey = fileURL + } + return s.client.RemoveObject(context.Background(), s.config.Bucket, objectKey, minio.RemoveObjectOptions{}) +} + +var _ Uploader = MiniOss{} diff --git a/service/oss/qiniu_oss.go b/service/oss/qiniu_oss.go new file mode 100644 index 0000000..703b6d7 --- /dev/null +++ b/service/oss/qiniu_oss.go @@ -0,0 +1,151 @@ +package oss + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "geekai/core/types" + "geekai/utils" + "net/url" + "path/filepath" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/qiniu/go-sdk/v7/auth/qbox" + "github.com/qiniu/go-sdk/v7/storage" +) + +type QinNiuOss struct { + config *types.QiNiuOssConfig + mac *qbox.Mac + putPolicy storage.PutPolicy + uploader *storage.FormUploader + manager *storage.BucketManager + proxyURL string +} + +func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss { + config := &appConfig.OSS.QiNiu + // build storage uploader + zone, ok := storage.GetRegionByID(storage.RegionID(config.Zone)) + if !ok { + zone = storage.ZoneHuanan + } + storeConfig := storage.Config{Zone: &zone} + formUploader := storage.NewFormUploader(&storeConfig) + // generate token + mac := qbox.NewMac(config.AccessKey, config.AccessSecret) + putPolicy := storage.PutPolicy{ + Scope: config.Bucket, + } + if config.SubDir == "" { + config.SubDir = "gpt" + } + return QinNiuOss{ + config: config, + mac: mac, + putPolicy: putPolicy, + uploader: formUploader, + manager: storage.NewBucketManager(mac, &storeConfig), + proxyURL: appConfig.ProxyURL, + } +} + +func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) { + // 解析表单 + file, err := ctx.FormFile(name) + if err != nil { + return File{}, err + } + // 打开上传文件 + src, err := file.Open() + if err != nil { + return File{}, err + } + defer src.Close() + + fileExt := filepath.Ext(file.Filename) + key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) + // 上传文件 + ret := storage.PutRet{} + extra := storage.PutExtra{} + err = s.uploader.Put(ctx, &ret, s.putPolicy.UploadToken(s.mac), key, src, file.Size, &extra) + if err != nil { + return File{}, err + } + + return File{ + Name: file.Filename, + ObjKey: key, + URL: fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), + Ext: fileExt, + Size: file.Size, + }, nil + +} + +func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) { + var imageData []byte + var err error + if useProxy { + imageData, err = utils.DownloadImage(imageURL, s.proxyURL) + } else { + imageData, err = utils.DownloadImage(imageURL, "") + } + if err != nil { + return "", fmt.Errorf("error with download image: %v", err) + } + parse, err := url.Parse(imageURL) + if err != nil { + return "", fmt.Errorf("error with parse image URL: %v", err) + } + fileExt := utils.GetImgExt(parse.Path) + key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) + ret := storage.PutRet{} + extra := storage.PutExtra{} + // 上传文件字节数据 + err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), key, bytes.NewReader(imageData), int64(len(imageData)), &extra) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil +} + +func (s QinNiuOss) PutBase64(base64Img string) (string, error) { + imageData, err := base64.StdEncoding.DecodeString(base64Img) + if err != nil { + return "", fmt.Errorf("error decoding base64:%v", err) + } + objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro()) + ret := storage.PutRet{} + extra := storage.PutExtra{} + // 上传文件字节数据 + err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil +} + +func (s QinNiuOss) Delete(fileURL string) error { + var objectKey string + if strings.HasPrefix(fileURL, "http") { + filename := filepath.Base(fileURL) + objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename) + } else { + objectKey = fileURL + } + + return s.manager.Delete(s.config.Bucket, objectKey) +} + +var _ Uploader = QinNiuOss{} diff --git a/service/oss/uploader.go b/service/oss/uploader.go new file mode 100644 index 0000000..435e22d --- /dev/null +++ b/service/oss/uploader.go @@ -0,0 +1,29 @@ +package oss + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import "github.com/gin-gonic/gin" + +const Local = "LOCAL" +const Minio = "MINIO" +const QiNiu = "QINIU" +const AliYun = "ALIYUN" + +type File struct { + Name string `json:"name"` + ObjKey string `json:"obj_key"` + Size int64 `json:"size"` + URL string `json:"url"` + Ext string `json:"ext"` +} +type Uploader interface { + PutFile(ctx *gin.Context, name string) (File, error) + PutImg(imageURL string, useProxy bool) (string, error) + PutBase64(imageData string) (string, error) + Delete(fileURL string) error +} diff --git a/service/oss/uploader_manager.go b/service/oss/uploader_manager.go new file mode 100644 index 0000000..573891b --- /dev/null +++ b/service/oss/uploader_manager.go @@ -0,0 +1,53 @@ +package oss + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core/types" + "strings" +) + +type UploaderManager struct { + handler Uploader +} + +func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) { + active := Local + if config.OSS.Active != "" { + active = strings.ToUpper(config.OSS.Active) + } + var handler Uploader + switch active { + case Local: + handler = NewLocalStorage(config) + break + case Minio: + client, err := NewMiniOss(config) + if err != nil { + return nil, err + } + handler = client + break + case QiNiu: + handler = NewQiNiuOss(config) + break + case AliYun: + client, err := NewAliYunOss(config) + if err != nil { + return nil, err + } + handler = client + break + } + + return &UploaderManager{handler: handler}, nil +} + +func (m *UploaderManager) GetUploadHandler() Uploader { + return m.handler +} diff --git a/service/payment/alipay_service.go b/service/payment/alipay_service.go new file mode 100644 index 0000000..228949d --- /dev/null +++ b/service/payment/alipay_service.go @@ -0,0 +1,149 @@ +package payment + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core/types" + logger2 "geekai/logger" + "github.com/smartwalle/alipay/v3" + "log" + "net/url" + "os" +) + +type AlipayService struct { + config *types.AlipayConfig + client *alipay.Client +} + +var logger = logger2.GetLogger() + +func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) { + config := appConfig.AlipayConfig + if !config.Enabled { + logger.Info("Disabled Alipay service") + return nil, nil + } + priKey, err := readKey(config.PrivateKey) + if err != nil { + return nil, fmt.Errorf("error with read App Private key: %v", err) + } + + xClient, err := alipay.New(config.AppId, priKey, !config.SandBox) + if err != nil { + return nil, fmt.Errorf("error with initialize alipay service: %v", err) + } + + if err = xClient.LoadAppCertPublicKeyFromFile(config.PublicKey); err != nil { + return nil, fmt.Errorf("error with loading App PublicKey: %v", err) + } + if err = xClient.LoadAliPayRootCertFromFile(config.RootCert); err != nil { + return nil, fmt.Errorf("error with loading alipay RootCert: %v", err) + } + if err = xClient.LoadAlipayCertPublicKeyFromFile(config.AlipayPublicKey); err != nil { + return nil, fmt.Errorf("error with loading Alipay PublicKey: %v", err) + } + + return &AlipayService{config: &config, client: xClient}, nil +} + +func (s *AlipayService) PayUrlMobile(outTradeNo string, notifyURL string, returnURL string, Amount string, subject string) (string, error) { + var p = alipay.TradeWapPay{} + p.NotifyURL = notifyURL + p.ReturnURL = returnURL + p.Subject = subject + p.OutTradeNo = outTradeNo + p.TotalAmount = Amount + p.ProductCode = "QUICK_WAP_WAY" + res, err := s.client.TradeWapPay(p) + if err != nil { + return "", err + } + + return res.String(), err +} + +func (s *AlipayService) PayUrlPc(outTradeNo string, notifyURL string, returnURL string, amount string, subject string) (string, error) { + var p = alipay.TradePagePay{} + p.NotifyURL = notifyURL + p.ReturnURL = returnURL + p.Subject = subject + p.OutTradeNo = outTradeNo + p.TotalAmount = amount + p.ProductCode = "FAST_INSTANT_TRADE_PAY" + res, err := s.client.TradePagePay(p) + if err != nil { + return "", nil + } + + return res.String(), err +} + +// TradeVerify 交易验证 +func (s *AlipayService) TradeVerify(reqForm url.Values) NotifyVo { + err := s.client.VerifySign(reqForm) + if err != nil { + log.Println("异步通知验证签名发生错误", err) + return NotifyVo{ + Status: 0, + Message: "异步通知验证签名发生错误", + } + } + + return s.TradeQuery(reqForm.Get("out_trade_no")) +} + +func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo { + var p = alipay.TradeQuery{} + p.OutTradeNo = outTradeNo + rsp, err := s.client.TradeQuery(p) + if err != nil { + return NotifyVo{ + Status: 0, + Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(), + } + } + + if rsp.IsSuccess() == true && rsp.TradeStatus == "TRADE_SUCCESS" { + return NotifyVo{ + Status: 1, + OutTradeNo: rsp.OutTradeNo, + TradeNo: rsp.TradeNo, + Amount: rsp.TotalAmount, + Subject: rsp.Subject, + Message: "OK", + } + } else { + return NotifyVo{ + Status: 0, + Message: "异步查询验证订单信息发生错误" + outTradeNo, + } + } +} + +func readKey(filename string) (string, error) { + data, err := os.ReadFile(filename) + if err != nil { + return "", err + } + return string(data), nil +} + +type NotifyVo struct { + Status int + OutTradeNo string + TradeNo string + Amount string + Message string + Subject string +} + +func (v NotifyVo) Success() bool { + return v.Status == 1 +} diff --git a/service/payment/hupipay_serive.go b/service/payment/hupipay_serive.go new file mode 100644 index 0000000..69a2e21 --- /dev/null +++ b/service/payment/hupipay_serive.go @@ -0,0 +1,171 @@ +package payment + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "crypto/md5" + "encoding/hex" + "errors" + "fmt" + "geekai/core/types" + "geekai/utils" + "io" + "net/http" + "net/url" + "sort" + "strconv" + "strings" + "time" +) + +type HuPiPayService struct { + appId string + appSecret string + apiURL string +} + +func NewHuPiPay(config *types.AppConfig) *HuPiPayService { + return &HuPiPayService{ + appId: config.HuPiPayConfig.AppId, + appSecret: config.HuPiPayConfig.AppSecret, + apiURL: config.HuPiPayConfig.ApiURL, + } +} + +type HuPiPayReq struct { + AppId string `json:"appid"` + Version string `json:"version"` + TradeOrderId string `json:"trade_order_id"` + TotalFee string `json:"total_fee"` + Title string `json:"title"` + NotifyURL string `json:"notify_url"` + ReturnURL string `json:"return_url"` + WapName string `json:"wap_name"` + CallbackURL string `json:"callback_url"` + Time string `json:"time"` + NonceStr string `json:"nonce_str"` + Type string `json:"type"` + WapUrl string `json:"wap_url"` +} + +type HuPiResp struct { + Openid interface{} `json:"openid"` + UrlQrcode string `json:"url_qrcode"` + URL string `json:"url"` + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg,omitempty"` +} + +// Pay 执行支付请求操作 +func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) { + data := url.Values{} + simple := strconv.FormatInt(time.Now().Unix(), 10) + params.AppId = s.appId + params.Time = simple + params.NonceStr = simple + encode := utils.JsonEncode(params) + m := make(map[string]string) + _ = utils.JsonDecode(encode, &m) + for k, v := range m { + data.Add(k, fmt.Sprintf("%v", v)) + } + // 生成签名 + data.Add("hash", s.Sign(data)) + // 发送支付请求 + apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL) + resp, err := http.PostForm(apiURL, data) + if err != nil { + return HuPiResp{}, fmt.Errorf("error with requst api: %v", err) + } + defer resp.Body.Close() + all, err := io.ReadAll(resp.Body) + if err != nil { + return HuPiResp{}, fmt.Errorf("error with reading response: %v", err) + } + + var res HuPiResp + err = utils.JsonDecode(string(all), &res) + if err != nil { + return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err) + } + + if res.ErrCode != 0 { + return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg) + } + + return res, nil +} + +// Sign 签名方法 +func (s *HuPiPayService) Sign(params url.Values) string { + params.Del(`Sign`) + var keys = make([]string, 0, 0) + for key := range params { + if params.Get(key) != `` { + keys = append(keys, key) + } + } + sort.Strings(keys) + + var pList = make([]string, 0, 0) + for _, key := range keys { + var value = strings.TrimSpace(params.Get(key)) + if len(value) > 0 { + pList = append(pList, key+"="+value) + } + } + var src = strings.Join(pList, "&") + src += s.appSecret + + md5bs := md5.Sum([]byte(src)) + return hex.EncodeToString(md5bs[:]) +} + +// Check 校验订单状态 +func (s *HuPiPayService) Check(tradeNo string) error { + data := url.Values{} + data.Add("appid", s.appId) + data.Add("open_order_id", tradeNo) + stamp := strconv.FormatInt(time.Now().Unix(), 10) + data.Add("time", stamp) + data.Add("nonce_str", stamp) + data.Add("hash", s.Sign(data)) + + apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL) + resp, err := http.PostForm(apiURL, data) + if err != nil { + return fmt.Errorf("error with http reqeust: %v", err) + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error with reading response: %v", err) + } + + var r struct { + ErrCode int `json:"errcode"` + Data struct { + Status string `json:"status"` + OpenOrderId string `json:"open_order_id"` + } `json:"data,omitempty"` + ErrMsg string `json:"errmsg"` + Hash string `json:"hash"` + } + err = utils.JsonDecode(string(body), &r) + if err != nil { + return fmt.Errorf("error with decode response: %v", err) + } + + if r.ErrCode == 0 && r.Data.Status == "OD" { + return nil + } else { + logger.Debugf("%+v", r) + return errors.New("order not paid:" + r.ErrMsg) + } +} diff --git a/service/payment/payjs_service.go b/service/payment/payjs_service.go new file mode 100644 index 0000000..1b42406 --- /dev/null +++ b/service/payment/payjs_service.go @@ -0,0 +1,155 @@ +package payment + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "crypto/md5" + "encoding/hex" + "errors" + "fmt" + "geekai/core/types" + "geekai/utils" + "io" + "net/http" + "net/url" + "sort" + "strings" +) + +type PayJS struct { + config *types.JPayConfig +} + +func NewPayJS(appConfig *types.AppConfig) *PayJS { + return &PayJS{ + config: &appConfig.JPayConfig, + } +} + +type JPayReq struct { + TotalFee int `json:"total_fee"` + OutTradeNo string `json:"out_trade_no"` + Subject string `json:"body"` + NotifyURL string `json:"notify_url"` + ReturnURL string `json:"callback_url"` +} +type JPayReps struct { + OutTradeNo string `json:"out_trade_no"` + OrderId string `json:"payjs_order_id"` + ReturnCode int `json:"return_code"` + ReturnMsg string `json:"return_msg"` + Sign string `json:"Sign"` + TotalFee string `json:"total_fee"` + CodeUrl string `json:"code_url,omitempty"` + Qrcode string `json:"qrcode,omitempty"` +} + +func (r JPayReps) IsOK() bool { + return r.ReturnMsg == "SUCCESS" +} + +func (js *PayJS) Pay(param JPayReq) JPayReps { + param.NotifyURL = js.config.NotifyURL + var p = url.Values{} + encode := utils.JsonEncode(param) + m := make(map[string]interface{}) + _ = utils.JsonDecode(encode, &m) + for k, v := range m { + p.Add(k, fmt.Sprintf("%v", v)) + } + p.Add("mchid", js.config.AppId) + + p.Add("sign", js.sign(p)) + + cli := http.Client{} + apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL) + r, err := cli.PostForm(apiURL, p) + if err != nil { + return JPayReps{ReturnMsg: err.Error()} + } + defer r.Body.Close() + bs, err := io.ReadAll(r.Body) + if err != nil { + return JPayReps{ReturnMsg: err.Error()} + } + + var data JPayReps + err = utils.JsonDecode(string(bs), &data) + if err != nil { + return JPayReps{ReturnMsg: err.Error()} + } + return data +} + +func (js *PayJS) PayH5(p url.Values) string { + p.Add("mchid", js.config.AppId) + p.Add("sign", js.sign(p)) + return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode()) +} + +func (js *PayJS) sign(params url.Values) string { + params.Del(`sign`) + var keys = make([]string, 0, 0) + for key := range params { + if params.Get(key) != `` { + keys = append(keys, key) + } + } + sort.Strings(keys) + + var pList = make([]string, 0, 0) + for _, key := range keys { + var value = strings.TrimSpace(params.Get(key)) + if len(value) > 0 { + pList = append(pList, key+"="+value) + } + } + var src = strings.Join(pList, "&") + src += "&key=" + js.config.PrivateKey + + md5bs := md5.Sum([]byte(src)) + md5res := hex.EncodeToString(md5bs[:]) + return strings.ToUpper(md5res) +} + +// Check 查询订单支付状态 +// @param tradeNo 支付平台交易 ID +func (js *PayJS) Check(tradeNo string) error { + apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL) + params := url.Values{} + params.Add("payjs_order_id", tradeNo) + params.Add("sign", js.sign(params)) + data := strings.NewReader(params.Encode()) + resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data) + defer resp.Body.Close() + if err != nil { + return fmt.Errorf("error with http reqeust: %v", err) + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error with reading response: %v", err) + } + + var r struct { + ReturnCode int `json:"return_code"` + Status int `json:"status"` + } + err = utils.JsonDecode(string(body), &r) + if err != nil { + return fmt.Errorf("error with decode response: %v", err) + } + + if r.ReturnCode == 1 && r.Status == 1 { + return nil + } else { + logger.Errorf("PayJs 支付验证响应:%s", string(body)) + return errors.New("order not paid") + } +} diff --git a/service/sd/pool.go b/service/sd/pool.go new file mode 100644 index 0000000..55329e4 --- /dev/null +++ b/service/sd/pool.go @@ -0,0 +1,143 @@ +package sd + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core/types" + "geekai/service/oss" + "geekai/store" + "geekai/store/model" + "time" + + "github.com/go-redis/redis/v8" + "gorm.io/gorm" +) + +type ServicePool struct { + services []*Service + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + db *gorm.DB + Clients *types.LMap[uint, *types.WsClient] // UserId => Client + uploader *oss.UploaderManager + levelDB *store.LevelDB +} + +func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, levelDB *store.LevelDB) *ServicePool { + services := make([]*Service, 0) + taskQueue := store.NewRedisQueue("StableDiffusion_Task_Queue", redisCli) + notifyQueue := store.NewRedisQueue("StableDiffusion_Queue", redisCli) + + return &ServicePool{ + taskQueue: taskQueue, + notifyQueue: notifyQueue, + services: services, + db: db, + Clients: types.NewLMap[uint, *types.WsClient](), + uploader: manager, + levelDB: levelDB, + } +} + +func (p *ServicePool) InitServices(configs []types.StableDiffusionConfig) { + // stop old service + for _, s := range p.services { + s.Stop() + } + p.services = make([]*Service, 0) + + for k, config := range configs { + if config.Enabled == false { + continue + } + + // create sd service + name := fmt.Sprintf(" sd-service-%d", k) + service := NewService(name, config, p.taskQueue, p.notifyQueue, p.db, p.uploader, p.levelDB) + // run sd service + go func() { + service.Run() + }() + + p.services = append(p.services, service) + } +} + +// PushTask push a new mj task in to task queue +func (p *ServicePool) PushTask(task types.SdTask) { + logger.Debugf("add a new MidJourney task to the task list: %+v", task) + p.taskQueue.RPush(task) +} + +func (p *ServicePool) CheckTaskNotify() { + go func() { + logger.Info("Running Stable-Diffusion task notify checking ...") + for { + var message NotifyMessage + err := p.notifyQueue.LPop(&message) + if err != nil { + continue + } + client := p.Clients.Get(uint(message.UserId)) + if client == nil { + continue + } + err = client.Send([]byte(message.Message)) + if err != nil { + continue + } + } + }() +} + +// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务 +func (p *ServicePool) CheckTaskStatus() { + go func() { + logger.Info("Running Stable-Diffusion task status checking ...") + for { + var jobs []model.SdJob + res := p.db.Where("progress < ?", 100).Find(&jobs) + if res.Error != nil { + time.Sleep(5 * time.Second) + continue + } + + for _, job := range jobs { + // 5 分钟还没完成的任务直接删除 + if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 { + p.db.Delete(&job) + var user model.User + p.db.Where("id = ?", job.UserId).First(&user) + // 退回绘图次数 + res = p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power)) + if res.Error == nil && res.RowsAffected > 0 { + p.db.Create(&model.PowerLog{ + UserId: user.Id, + Username: user.Username, + Type: types.PowerConsume, + Amount: job.Power, + Balance: user.Power + job.Power, + Mark: types.PowerAdd, + Model: "stable-diffusion", + Remark: fmt.Sprintf("任务失败,退回算力。任务ID:%s", job.TaskId), + CreatedAt: time.Now(), + }) + } + continue + } + } + time.Sleep(time.Second * 10) + } + }() +} + +// HasAvailableService check if it has available mj service in pool +func (p *ServicePool) HasAvailableService() bool { + return len(p.services) > 0 +} diff --git a/service/sd/service.go b/service/sd/service.go new file mode 100644 index 0000000..736f418 --- /dev/null +++ b/service/sd/service.go @@ -0,0 +1,247 @@ +package sd + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core/types" + "geekai/service" + "geekai/service/oss" + "geekai/store" + "geekai/store/model" + "geekai/utils" + "strings" + "time" + + "github.com/imroc/req/v3" + "gorm.io/gorm" +) + +// SD 绘画服务 + +type Service struct { + httpClient *req.Client + config types.StableDiffusionConfig + taskQueue *store.RedisQueue + notifyQueue *store.RedisQueue + db *gorm.DB + uploadManager *oss.UploaderManager + name string // service name + leveldb *store.LevelDB + running bool // 运行状态 +} + +func NewService(name string, config types.StableDiffusionConfig, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, manager *oss.UploaderManager, levelDB *store.LevelDB) *Service { + config.ApiURL = strings.TrimRight(config.ApiURL, "/") + return &Service{ + name: name, + config: config, + httpClient: req.C(), + taskQueue: taskQueue, + notifyQueue: notifyQueue, + db: db, + leveldb: levelDB, + uploadManager: manager, + running: true, + } +} + +func (s *Service) Run() { + logger.Infof("Starting Stable-Diffusion job consumer for %s", s.name) + for s.running { + var task types.SdTask + err := s.taskQueue.LPop(&task) + if err != nil { + logger.Errorf("taking task with error: %v", err) + continue + } + + // translate prompt + if utils.HasChinese(task.Params.Prompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Params.Prompt)) + if err == nil { + task.Params.Prompt = content + } else { + logger.Warnf("error with translate prompt: %v", err) + } + } + + // translate negative prompt + if task.Params.NegPrompt != "" && utils.HasChinese(task.Params.NegPrompt) { + content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Params.NegPrompt)) + if err == nil { + task.Params.NegPrompt = content + } else { + logger.Warnf("error with translate prompt: %v", err) + } + } + + logger.Infof("%s handle a new Stable-Diffusion task: %+v", s.name, task) + err = s.Txt2Img(task) + if err != nil { + logger.Error("绘画任务执行失败:", err.Error()) + // update the task progress + s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(map[string]interface{}{ + "progress": -1, + "err_msg": err.Error(), + }) + // 通知前端,任务失败 + s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Failed}) + continue + } + } +} + +func (s *Service) Stop() { + s.running = false +} + +// Txt2ImgReq 文生图请求实体 +type Txt2ImgReq struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt"` + Seed int64 `json:"seed,omitempty"` + Steps int `json:"steps"` + CfgScale float32 `json:"cfg_scale"` + Width int `json:"width"` + Height int `json:"height"` + SamplerName string `json:"sampler_name"` + Scheduler string `json:"scheduler"` + EnableHr bool `json:"enable_hr,omitempty"` + HrScale int `json:"hr_scale,omitempty"` + HrUpscaler string `json:"hr_upscaler,omitempty"` + HrSecondPassSteps int `json:"hr_second_pass_steps,omitempty"` + DenoisingStrength float32 `json:"denoising_strength,omitempty"` + ForceTaskId string `json:"force_task_id,omitempty"` +} + +// Txt2ImgResp 文生图响应实体 +type Txt2ImgResp struct { + Images []string `json:"images"` + Parameters struct { + } `json:"parameters"` + Info string `json:"info"` +} + +// TaskProgressResp 任务进度响应实体 +type TaskProgressResp struct { + Progress float64 `json:"progress"` + EtaRelative float64 `json:"eta_relative"` + CurrentImage string `json:"current_image"` +} + +// Txt2Img 文生图 API +func (s *Service) Txt2Img(task types.SdTask) error { + body := Txt2ImgReq{ + Prompt: task.Params.Prompt, + NegativePrompt: task.Params.NegPrompt, + Steps: task.Params.Steps, + CfgScale: task.Params.CfgScale, + Width: task.Params.Width, + Height: task.Params.Height, + SamplerName: task.Params.Sampler, + Scheduler: task.Params.Scheduler, + ForceTaskId: task.Params.TaskId, + } + if task.Params.Seed > 0 { + body.Seed = task.Params.Seed + } + if task.Params.HdFix { + body.EnableHr = true + body.HrScale = task.Params.HdScale + body.HrUpscaler = task.Params.HdScaleAlg + body.HrSecondPassSteps = task.Params.HdSteps + body.DenoisingStrength = task.Params.HdRedrawRate + } + var res Txt2ImgResp + var errChan = make(chan error) + apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL) + logger.Debugf("send image request to %s", apiURL) + // send a request to sd api endpoint + go func() { + response, err := s.httpClient.R(). + SetHeader("Authorization", s.config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + Post(apiURL) + if err != nil { + errChan <- err + return + } + if response.IsErrorState() { + errChan <- fmt.Errorf("error http code status: %v", response.Status) + return + } + + // 保存 Base64 图片 + imgURL, err := s.uploadManager.GetUploadHandler().PutBase64(res.Images[0]) + if err != nil { + errChan <- fmt.Errorf("error with upload image: %v", err) + return + } + // 获取绘画真实的 seed + var info map[string]interface{} + err = utils.JsonDecode(res.Info, &info) + if err != nil { + errChan <- fmt.Errorf("error with decode task response: %v", err) + return + } + task.Params.Seed = int64(utils.IntValue(utils.InterfaceToString(info["seed"]), -1)) + s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumns(model.SdJob{ImgURL: imgURL, Params: utils.JsonEncode(task.Params)}) + errChan <- nil + }() + + // waiting for task finish + for { + select { + case err := <-errChan: + if err != nil { + return err + } + + // task finished + s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", 100) + s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Finished}) + // 从 leveldb 中删除预览图片数据 + _ = s.leveldb.Delete(task.Params.TaskId) + return nil + default: + err, resp := s.checkTaskProgress() + // 更新任务进度 + if err == nil && resp.Progress > 0 { + s.db.Model(&model.SdJob{Id: uint(task.Id)}).UpdateColumn("progress", int(resp.Progress*100)) + // 发送更新状态信号 + s.notifyQueue.RPush(NotifyMessage{UserId: task.UserId, JobId: task.Id, Message: Running}) + // 保存预览图片数据 + if resp.CurrentImage != "" { + _ = s.leveldb.Put(task.Params.TaskId, resp.CurrentImage) + } + } + time.Sleep(time.Second) + } + } + +} + +// 执行任务 +func (s *Service) checkTaskProgress() (error, *TaskProgressResp) { + apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL) + var res TaskProgressResp + response, err := s.httpClient.R(). + SetHeader("Authorization", s.config.ApiKey). + SetSuccessResult(&res). + Get(apiURL) + if err != nil { + return err, nil + } + if response.IsErrorState() { + return fmt.Errorf("error http code status: %v", response.Status), nil + } + + return nil, &res +} diff --git a/service/sd/types.go b/service/sd/types.go new file mode 100644 index 0000000..efdb970 --- /dev/null +++ b/service/sd/types.go @@ -0,0 +1,24 @@ +package sd + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import logger2 "geekai/logger" + +var logger = logger2.GetLogger() + +type NotifyMessage struct { + UserId int `json:"user_id"` + JobId int `json:"job_id"` + Message string `json:"message"` +} + +const ( + Running = "RUNNING" + Finished = "FINISH" + Failed = "FAIL" +) diff --git a/service/sms/aliyun.go b/service/sms/aliyun.go new file mode 100644 index 0000000..d0ea1b9 --- /dev/null +++ b/service/sms/aliyun.go @@ -0,0 +1,60 @@ +package sms + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core/types" + "github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi" +) + +type AliYunSmsService struct { + config *types.SmsConfigAli + client *dysmsapi.Client +} + +func NewAliYunSmsService(appConfig *types.AppConfig) (*AliYunSmsService, error) { + config := &appConfig.SMS.Ali + // 创建阿里云短信客户端 + client, err := dysmsapi.NewClientWithAccessKey( + "cn-hangzhou", + config.AccessKey, + config.AccessSecret) + if err != nil { + return nil, fmt.Errorf("failed to create client: %v", err) + } + + return &AliYunSmsService{ + config: config, + client: client, + }, nil +} + +func (s *AliYunSmsService) SendVerifyCode(mobile string, code int) error { + // 创建短信请求并设置参数 + request := dysmsapi.CreateSendSmsRequest() + request.Scheme = "https" + request.Domain = s.config.Domain + request.PhoneNumbers = mobile + request.SignName = s.config.Sign + request.TemplateCode = s.config.CodeTempId + request.TemplateParam = fmt.Sprintf("{\"code\":\"%d\"}", code) // 短信模板中的参数 + + // 发送短信 + response, err := s.client.SendSms(request) + if err != nil { + return fmt.Errorf("failed to send SMS:%v", err) + } + + if response.Code != "OK" { + return fmt.Errorf("failed to send SMS:%v", response.Message) + } + return nil +} + +var _ Service = &AliYunSmsService{} diff --git a/service/sms/bao.go b/service/sms/bao.go new file mode 100644 index 0000000..a00398d --- /dev/null +++ b/service/sms/bao.go @@ -0,0 +1,79 @@ +package sms + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core/types" + "geekai/utils" + "io" + "net/http" + "net/url" + "strconv" + "strings" +) + +type BaoSmsService struct { + config *types.SmsConfigBao +} + +func NewSmsBaoSmsService(appConfig *types.AppConfig) *BaoSmsService { + config := appConfig.SMS.Bao + if config.Domain == "" { // use default domain + config.Domain = "api.smsbao.com" + logger.Infof("Using default domain for SMS-BAO: %s", config.Domain) + } + return &BaoSmsService{ + config: &config, + } +} + +var errMsg = map[string]string{ + "0": "短信发送成功", + "-1": "参数不全", + "-2": "服务器空间不支持,请确认支持curl或者fsocket,联系您的空间商解决或者更换空间", + "30": "密码错误", + "40": "账号不存在", + "41": "余额不足", + "42": "账户已过期", + "43": "IP地址限制", + "50": "内容含有敏感词", +} + +func (s *BaoSmsService) SendVerifyCode(mobile string, code int) error { + + content := fmt.Sprintf("%s%s", s.config.Sign, s.config.CodeTemplate) + content = strings.ReplaceAll(content, "{code}", strconv.Itoa(code)) + password := utils.Md5(s.config.Password) + params := url.Values{} + params.Set("u", s.config.Username) + params.Set("p", password) + params.Set("m", mobile) + params.Set("c", content) + + apiURL := fmt.Sprintf("https://%s/sms?%s", s.config.Domain, params.Encode()) + response, err := http.Get(apiURL) + if err != nil { + return err + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return err + } + result := string(body) + logger.Debugf("send SmsBao result: %v", errMsg[result]) + + if result != "0" { + return fmt.Errorf("failed to send SMS:%v", errMsg[result]) + } + return nil +} + +var _ Service = &BaoSmsService{} diff --git a/service/sms/service.go b/service/sms/service.go new file mode 100644 index 0000000..14d12ca --- /dev/null +++ b/service/sms/service.go @@ -0,0 +1,15 @@ +package sms + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +const Ali = "ALI" +const Bao = "BAO" + +type Service interface { + SendVerifyCode(mobile string, code int) error +} diff --git a/service/sms/service_manager.go b/service/sms/service_manager.go new file mode 100644 index 0000000..0a4fcac --- /dev/null +++ b/service/sms/service_manager.go @@ -0,0 +1,46 @@ +package sms + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core/types" + logger2 "geekai/logger" + "strings" +) + +type ServiceManager struct { + handler Service +} + +var logger = logger2.GetLogger() + +func NewSendServiceManager(config *types.AppConfig) (*ServiceManager, error) { + active := Ali + if config.SMS.Active != "" { + active = strings.ToUpper(config.SMS.Active) + } + var handler Service + switch active { + case Ali: + client, err := NewAliYunSmsService(config) + if err != nil { + return nil, err + } + handler = client + break + case Bao: + handler = NewSmsBaoSmsService(config) + break + } + + return &ServiceManager{handler: handler}, nil +} + +func (m *ServiceManager) GetService() Service { + return m.handler +} diff --git a/service/smtp_sms_service.go b/service/smtp_sms_service.go new file mode 100644 index 0000000..e93e926 --- /dev/null +++ b/service/smtp_sms_service.go @@ -0,0 +1,131 @@ +package service + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bytes" + "crypto/tls" + "fmt" + "geekai/core/types" + "mime" + "net/smtp" + "net/textproto" +) + +type SmtpService struct { + config *types.SmtpConfig +} + +func NewSmtpService(appConfig *types.AppConfig) *SmtpService { + return &SmtpService{ + config: &appConfig.SmtpConfig, + } +} + +func (s *SmtpService) SendVerifyCode(to string, code int) error { + subject := "Geek-AI 注册验证码" + body := fmt.Sprintf("您正在注册 Geek-AI 助手账户,注册验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", code) + + auth := smtp.PlainAuth("", s.config.From, s.config.Password, s.config.Host) + if s.config.UseTls { + return s.sendTLS(auth, to, subject, body) + } else { + return s.send(auth, to, subject, body) + } +} + +func (s *SmtpService) send(auth smtp.Auth, to string, subject string, body string) error { + // 对主题进行MIME编码 + encodedSubject := mime.QEncoding.Encode("UTF-8", subject) + // 组装邮件 + message := bytes.NewBuffer(nil) + message.WriteString(fmt.Sprintf("From: \"%s\" <%s>\r\n", s.config.AppName, s.config.From)) + message.WriteString(fmt.Sprintf("To: %s\r\n", to)) + message.WriteString(fmt.Sprintf("Subject: %s\r\n", encodedSubject)) + message.WriteString("\r\n" + body) + + // 发送邮件 + err := smtp.SendMail(s.config.Host+":"+fmt.Sprint(s.config.Port), auth, s.config.From, []string{to}, message.Bytes()) + if err != nil { + return fmt.Errorf("error sending email: %v", err) + } + + return err + +} + +func (s *SmtpService) sendTLS(auth smtp.Auth, to string, subject string, body string) error { + // TLS配置 + tlsConfig := &tls.Config{ + ServerName: s.config.Host, + } + + // 建立TLS连接 + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", s.config.Host, s.config.Port), tlsConfig) + if err != nil { + return fmt.Errorf("error connecting to SMTP server: %v", err) + } + defer conn.Close() + + client, err := smtp.NewClient(conn, s.config.Host) + if err != nil { + return fmt.Errorf("error creating SMTP client: %v", err) + } + defer client.Quit() + + // 身份验证 + if err = client.Auth(auth); err != nil { + return fmt.Errorf("error authenticating: %v", err) + } + + // 设置寄件人 + if err = client.Mail(s.config.From); err != nil { + return fmt.Errorf("error setting sender: %v", err) + } + + // 设置收件人 + if err = client.Rcpt(to); err != nil { + return fmt.Errorf("error setting recipient: %v", err) + } + + // 发送邮件内容 + wc, err := client.Data() + if err != nil { + return fmt.Errorf("error getting data writer: %v", err) + } + defer wc.Close() + + header := make(textproto.MIMEHeader) + header.Set("From", s.config.From) + header.Set("To", to) + header.Set("Subject", subject) + + // 将邮件头写入 + for key, values := range header { + for _, value := range values { + _, err = fmt.Fprintf(wc, "%s: %s\r\n", key, value) + if err != nil { + return fmt.Errorf("error sending email header: %v", err) + } + } + } + _, _ = fmt.Fprintln(wc) + // 将邮件内容写入 + _, err = fmt.Fprintf(wc, body) + if err != nil { + return fmt.Errorf("error sending email: %v", err) + } + + // 发送完毕 + err = wc.Close() + if err != nil { + return fmt.Errorf("error closing data writer: %v", err) + } + + return nil +} diff --git a/service/snowflake.go b/service/snowflake.go new file mode 100644 index 0000000..d99caa8 --- /dev/null +++ b/service/snowflake.go @@ -0,0 +1,66 @@ +package service + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "sync" + "time" +) + +// Snowflake 雪花算法实现 +type Snowflake struct { + mu sync.Mutex + lastTimestamp int64 + workerID int + sequence int +} + +func NewSnowflake() *Snowflake { + return &Snowflake{ + lastTimestamp: -1, + workerID: 0, // TODO: 增加 WorkID 参数 + sequence: 0, + } +} + +// Next 生成一个新的唯一ID +func (s *Snowflake) Next(raw bool) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + timestamp := time.Now().UnixNano() / 1000000 // 转换为毫秒 + if timestamp < s.lastTimestamp { + return "", fmt.Errorf("clock moved backwards. Refusing to generate id for %d milliseconds", s.lastTimestamp-timestamp) + } + + if timestamp == s.lastTimestamp { + s.sequence = (s.sequence + 1) & 4095 + if s.sequence == 0 { + timestamp = s.waitNextMillis() + } + } else { + s.sequence = 0 + } + + s.lastTimestamp = timestamp + id := (timestamp << 22) | (int64(s.workerID) << 10) | int64(s.sequence) + if raw { + return fmt.Sprintf("%d", id), nil + } + now := time.Now() + return fmt.Sprintf("%d%02d%02d%d", now.Year(), now.Month(), now.Day(), id), nil +} + +func (s *Snowflake) waitNextMillis() int64 { + timestamp := time.Now().UnixNano() / 1000000 + for timestamp <= s.lastTimestamp { + timestamp = time.Now().UnixNano() / 1000000 + } + return timestamp +} diff --git a/service/types.go b/service/types.go new file mode 100644 index 0000000..15a538a --- /dev/null +++ b/service/types.go @@ -0,0 +1,4 @@ +package service + +const RewritePromptTemplate = "Please rewrite the following text into AI painting prompt words, and please try to add detailed description of the picture, painting style, scene, rendering effect, picture light and other creative elements. Just output the final prompt word directly. Do not output any explanation lines. The text to be rewritten is: [%s]" +const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]" diff --git a/service/wx/bot.go b/service/wx/bot.go new file mode 100644 index 0000000..4738458 --- /dev/null +++ b/service/wx/bot.go @@ -0,0 +1,101 @@ +package wx + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + logger2 "geekai/logger" + "geekai/store/model" + "github.com/eatmoreapple/openwechat" + "github.com/skip2/go-qrcode" + "gorm.io/gorm" + "os" + "strconv" +) + +// 微信收款机器人 +var logger = logger2.GetLogger() + +type Bot struct { + bot *openwechat.Bot + token string + db *gorm.DB +} + +func NewWeChatBot(db *gorm.DB) *Bot { + bot := openwechat.DefaultBot(openwechat.Desktop) + return &Bot{ + bot: bot, + db: db, + } +} + +func (b *Bot) Run() error { + logger.Info("Starting WeChat Bot...") + + // set message handler + b.bot.MessageHandler = func(msg *openwechat.Message) { + b.messageHandler(msg) + } + // scan code login callback + b.bot.UUIDCallback = b.qrCodeCallBack + debug, err := strconv.ParseBool(os.Getenv("APP_DEBUG")) + if debug { + reloadStorage := openwechat.NewJsonFileHotReloadStorage("storage.json") + err = b.bot.HotLogin(reloadStorage, true) + } else { + err = b.bot.Login() + } + if err != nil { + return err + } + + logger.Info("微信登录成功!") + return nil +} + +// message handler +func (b *Bot) messageHandler(msg *openwechat.Message) { + sender, err := msg.Sender() + if err != nil { + return + } + + // 只处理微信支付的推送消息 + if sender.NickName == "微信支付" || + msg.MsgType == openwechat.MsgTypeApp || + msg.AppMsgType == openwechat.AppMsgTypeUrl { + // 解析支付金额 + message := parseTransactionMessage(msg.Content) + transaction := extractTransaction(message) + logger.Infof("解析到收款信息:%+v", transaction) + if transaction.TransId != "" { + var item model.Reward + res := b.db.Where("tx_id = ?", transaction.TransId).First(&item) + if item.Id > 0 { + logger.Error("当前交易 ID 己经存在!") + return + } + + res = b.db.Create(&model.Reward{ + TxId: transaction.TransId, + Amount: transaction.Amount, + Remark: transaction.Remark, + Status: false, + }) + if res.Error != nil { + logger.Errorf("交易保存失败: %v", res.Error) + } + } + } +} + +func (b *Bot) qrCodeCallBack(uuid string) { + logger.Info("请使用微信扫描下面二维码登录") + q, _ := qrcode.New("https://login.weixin.qq.com/l/"+uuid, qrcode.Medium) + logger.Info(q.ToString(true)) +} diff --git a/service/wx/tranaction.go b/service/wx/tranaction.go new file mode 100644 index 0000000..2dfb529 --- /dev/null +++ b/service/wx/tranaction.go @@ -0,0 +1,112 @@ +package wx + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "encoding/xml" + "net/url" + "strconv" + "strings" +) + +// Message 转账消息 +type Message struct { + Des string + Url string +} + +// Transaction 解析后的交易信息 +type Transaction struct { + TransId string `json:"trans_id"` // 微信转账交易 ID + Amount float64 `json:"amount"` // 微信转账交易金额 + Remark string `json:"remark"` // 转账备注 +} + +// 解析微信转账消息 +func parseTransactionMessage(xmlData string) *Message { + decoder := xml.NewDecoder(strings.NewReader(xmlData)) + message := Message{} + for { + token, err := decoder.Token() + if err != nil { + break + } + + switch se := token.(type) { + case xml.StartElement: + var value string + if se.Name.Local == "des" && message.Des == "" { + if err := decoder.DecodeElement(&value, &se); err == nil { + message.Des = strings.TrimSpace(value) + } + break + } + if se.Name.Local == "weapp_path" || se.Name.Local == "url" { + if err := decoder.DecodeElement(&value, &se); err == nil { + if strings.Contains(value, "?trans_id=") || strings.Contains(value, "?id=") { + message.Url = value + } + } + break + } + } + } + + // 兼容旧版消息记录 + if message.Url == "" { + var msg struct { + XMLName xml.Name `xml:"msg"` + AppMsg struct { + Des string `xml:"des"` + Url string `xml:"url"` + } `xml:"appmsg"` + } + if err := xml.Unmarshal([]byte(xmlData), &msg); err == nil { + message.Url = msg.AppMsg.Url + } + } + return &message +} + +// 导出交易信息 +func extractTransaction(message *Message) Transaction { + var tx = Transaction{} + // 导出交易金额和备注 + lines := strings.Split(message.Des, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if len(line) == 0 { + continue + } + // 解析收款金额 + prefix := "收款金额¥" + if strings.HasPrefix(line, prefix) { + if value, err := strconv.ParseFloat(line[len(prefix):], 64); err == nil { + tx.Amount = value + continue + } + } + // 解析收款备注 + prefix = "付款方备注" + if strings.HasPrefix(line, prefix) { + tx.Remark = line[len(prefix):] + break + } + } + + // 解析交易 ID + parse, err := url.Parse(message.Url) + if err == nil { + tx.TransId = parse.Query().Get("id") + if tx.TransId == "" { + tx.TransId = parse.Query().Get("trans_id") + } + } + + return tx +} diff --git a/service/xxl_job_service.go b/service/xxl_job_service.go new file mode 100644 index 0000000..14fec1d --- /dev/null +++ b/service/xxl_job_service.go @@ -0,0 +1,196 @@ +package service + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "context" + "fmt" + "geekai/core/types" + logger2 "geekai/logger" + "geekai/store/model" + "geekai/utils" + "github.com/xxl-job/xxl-job-executor-go" + "gorm.io/gorm" + "time" +) + +var logger = logger2.GetLogger() + +type XXLJobExecutor struct { + executor xxl.Executor + db *gorm.DB +} + +func NewXXLJobExecutor(config *types.AppConfig, db *gorm.DB) *XXLJobExecutor { + if !config.XXLConfig.Enabled { + logger.Info("XXL-JOB service is disabled") + return nil + } + + exec := xxl.NewExecutor( + xxl.ServerAddr(config.XXLConfig.ServerAddr), + xxl.AccessToken(config.XXLConfig.AccessToken), //请求令牌(默认为空) + xxl.ExecutorIp(config.XXLConfig.ExecutorIp), //可自动获取 + xxl.ExecutorPort(config.XXLConfig.ExecutorPort), //默认9999(非必填) + xxl.RegistryKey(config.XXLConfig.RegistryKey), //执行器名称 + xxl.SetLogger(&customLogger{}), //自定义日志 + ) + exec.Init() + return &XXLJobExecutor{executor: exec, db: db} +} + +func (e *XXLJobExecutor) Run() error { + e.executor.RegTask("ClearOrders", e.ClearOrders) + e.executor.RegTask("ResetVipPower", e.ResetVipPower) + e.executor.RegTask("ResetUserPower", e.ResetUserPower) + return e.executor.Run() +} + +// ClearOrders 清理未支付的订单,如果没有抛出异常则表示执行成功 +func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (msg string) { + logger.Info("执行清理未支付订单...") + var sysConfig model.Config + res := e.db.Where("marker", "system").First(&sysConfig) + if res.Error != nil { + return "error with get system config: " + res.Error.Error() + } + + var config types.SystemConfig + err := utils.JsonDecode(sysConfig.Config, &config) + if err != nil { + return "error with decode system config: " + err.Error() + } + + if config.OrderPayTimeout == 0 { // 默认未支付订单的生命周期为 30 分钟 + config.OrderPayTimeout = 1800 + } + timeout := time.Now().Unix() - int64(config.OrderPayTimeout) + start := utils.Stamp2str(timeout) + // 这里不是用软删除,而是永久删除订单 + res = e.db.Unscoped().Where("status IN ? AND created_at < ?", []types.OrderStatus{types.OrderNotPaid, types.OrderScanned}, start).Delete(&model.Order{}) + logger.Infof("Clear order successfully, affect rows: %d", res.RowsAffected) + return "success" +} + +// ResetVipPower 重置VIP会员算力 +// 自动将 VIP 会员的算力补充到每月赠送的最大值 +func (e *XXLJobExecutor) ResetVipPower(cxt context.Context, param *xxl.RunReq) (msg string) { + logger.Info("开始进行月底账号盘点...") + var users []model.User + res := e.db.Where("vip", 1).Where("status", 1).Find(&users) + if res.Error != nil { + return "No vip users found" + } + + var sysConfig model.Config + res = e.db.Where("marker", "system").First(&sysConfig) + if res.Error != nil { + return "error with get system config: " + res.Error.Error() + } + + var config types.SystemConfig + err := utils.JsonDecode(sysConfig.Config, &config) + if err != nil { + return "error with decode system config: " + err.Error() + } + + for _, u := range users { + // 处理过期的 VIP + if u.ExpiredTime > 0 && u.ExpiredTime <= time.Now().Unix() { + u.Vip = false + e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("vip", false) + continue + } + // update user + tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", config.VipMonthPower)) + // 记录算力变动日志 + if tx.Error == nil { + var user model.User + e.db.Where("id", u.Id).First(&user) + e.db.Create(&model.PowerLog{ + UserId: u.Id, + Username: u.Username, + Type: types.PowerRecharge, + Amount: config.VipMonthPower, + Mark: types.PowerAdd, + Balance: user.Power, + Model: "系统盘点", + Remark: fmt.Sprintf("VIP会员每月算力派发,:%d", config.VipMonthPower), + CreatedAt: time.Now(), + }) + } + } + logger.Info("月底盘点完成!") + return "success" +} + +func (e *XXLJobExecutor) ResetUserPower(cxt context.Context, param *xxl.RunReq) (msg string) { + logger.Info("今日算力派发开始:", time.Now()) + var users []model.User + res := e.db.Where("status", 1).Find(&users) + if res.Error != nil { + return "No matching users" + } + + var sysConfig model.Config + res = e.db.Where("marker", "system").First(&sysConfig) + if res.Error != nil { + return "error with get system config: " + res.Error.Error() + } + + var config types.SystemConfig + err := utils.JsonDecode(sysConfig.Config, &config) + if err != nil { + return "error with decode system config: " + err.Error() + } + + if config.DailyPower <= 0 { + return "success" + } + + var counter = 0 + var totalPower = 0 + for _, u := range users { + if u.Power >= config.DailyPower { + continue + } + var power = config.DailyPower - u.Power + // update user + tx := e.db.Model(&model.User{}).Where("id", u.Id).UpdateColumn("power", gorm.Expr("power + ?", power)) + // 记录算力充值日志 + if tx.Error == nil { + var user model.User + e.db.Where("id", u.Id).First(&user) + e.db.Create(&model.PowerLog{ + UserId: u.Id, + Username: u.Username, + Type: types.PowerGift, + Amount: power, + Mark: types.PowerAdd, + Balance: user.Power, + Model: "系统赠送", + Remark: fmt.Sprintf("系统每日算力派发,今日额度:%d", config.DailyPower), + CreatedAt: time.Now(), + }) + } + counter++ + totalPower += power + } + logger.Infof("今日派发算力结束!累计派发 %d 人,累计派发算力:%d", counter, totalPower) + return "success" +} + +type customLogger struct{} + +func (l *customLogger) Info(format string, a ...interface{}) { + logger.Debugf(format, a...) +} + +func (l *customLogger) Error(format string, a ...interface{}) { + logger.Errorf(format, a...) +} diff --git a/store/leveldb.go b/store/leveldb.go new file mode 100644 index 0000000..5de0042 --- /dev/null +++ b/store/leveldb.go @@ -0,0 +1,116 @@ +package store + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "encoding/json" + + "github.com/syndtr/goleveldb/leveldb" + "github.com/syndtr/goleveldb/leveldb/util" +) + +type LevelDB struct { + driver *leveldb.DB +} + +func NewLevelDB() (*LevelDB, error) { + db, err := leveldb.OpenFile("data", nil) + if err != nil { + return nil, err + } + return &LevelDB{ + driver: db, + }, nil +} + +func (db *LevelDB) Put(key string, value interface{}) error { + var byteData []byte + if v, ok := value.(string); ok { + byteData = []byte(v) + } else { + b, err := json.Marshal(value) + if err != nil { + return err + } + byteData = b + } + return db.driver.Put([]byte(key), byteData, nil) +} + +func (db *LevelDB) Get(key string, dist interface{}) error { + bytes, err := db.driver.Get([]byte(key), nil) + if err != nil { + return err + } + return json.Unmarshal(bytes, dist) +} + +func (db *LevelDB) Search(prefix string) []string { + var items = make([]string, 0) + iter := db.driver.NewIterator(util.BytesPrefix([]byte(prefix)), nil) + defer iter.Release() + + for iter.Next() { + items = append(items, string(iter.Value())) + } + return items +} + +type PageVo struct { + Items []string + Page int + PageSize int + Total int + TotalPage int +} + +func (db *LevelDB) SearchPage(prefix string, page int, pageSize int) *PageVo { + var items = make([]string, 0) + iter := db.driver.NewIterator(util.BytesPrefix([]byte(prefix)), nil) + defer iter.Release() + + res := &PageVo{Page: page, PageSize: pageSize} + // 计算数据总数和总页数 + total := 0 + for iter.Next() { + total++ + } + res.TotalPage = (total + pageSize - 1) / pageSize + res.Total = total + + // 计算目标页码的起始和结束位置 + start := (page - 1) * pageSize + if start > total { + return nil + } + end := start + pageSize + if end > total { + end = total + } + + // 跳转到目标页码的起始位置 + count := 0 + for iter.Next() { + if count >= start { + items = append(items, string(iter.Value())) + } + count++ + } + iter.Release() + res.Items = items + return res +} + +func (db *LevelDB) Delete(key string) error { + return db.driver.Delete([]byte(key), nil) +} + +// Close release resources +func (db *LevelDB) Close() error { + return db.driver.Close() +} diff --git a/store/model/admin_user.go b/store/model/admin_user.go new file mode 100644 index 0000000..ffeccd1 --- /dev/null +++ b/store/model/admin_user.go @@ -0,0 +1,11 @@ +package model + +type AdminUser struct { + BaseModel + Username string + Password string + Salt string // 密码盐 + Status bool `gorm:"default:true"` // 当前状态 + LastLoginAt int64 // 最后登录时间 + LastLoginIp string // 最后登录 IP +} diff --git a/store/model/api_key.go b/store/model/api_key.go new file mode 100644 index 0000000..fb7ae1d --- /dev/null +++ b/store/model/api_key.go @@ -0,0 +1,14 @@ +package model + +// ApiKey OpenAI API 模型 +type ApiKey struct { + BaseModel + Platform string + Name string + Type string // 用途 chat => 聊天,img => 绘图 + Value string // API Key 的值 + ApiURL string // 当前 KEY 的 API 地址 + Enabled bool // 是否启用 + ProxyURL string // 代理地址 + LastUsedAt int64 // 最后使用时间 +} diff --git a/store/model/base.go b/store/model/base.go new file mode 100644 index 0000000..5246b4d --- /dev/null +++ b/store/model/base.go @@ -0,0 +1,9 @@ +package model + +import "time" + +type BaseModel struct { + Id uint `gorm:"primarykey;column:id"` + CreatedAt time.Time + UpdatedAt time.Time +} diff --git a/store/model/chat_history.go b/store/model/chat_history.go new file mode 100644 index 0000000..36abeb4 --- /dev/null +++ b/store/model/chat_history.go @@ -0,0 +1,21 @@ +package model + +import "gorm.io/gorm" + +type ChatMessage struct { + BaseModel + ChatId string // 会话 ID + UserId uint // 用户 ID + RoleId uint // 角色 ID + Model string // AI模型 + Type string + Icon string + Tokens int + Content string + UseContext bool // 是否可以作为聊天上下文 + DeletedAt gorm.DeletedAt +} + +func (ChatMessage) TableName() string { + return "chatgpt_chat_history" +} diff --git a/store/model/chat_item.go b/store/model/chat_item.go new file mode 100644 index 0000000..80b4590 --- /dev/null +++ b/store/model/chat_item.go @@ -0,0 +1,14 @@ +package model + +import "gorm.io/gorm" + +type ChatItem struct { + BaseModel + ChatId string `gorm:"column:chat_id;unique"` // 会话 ID + UserId uint // 用户 ID + RoleId uint // 角色 ID + ModelId uint // 模型 ID + Model string // 模型 + Title string // 会话标题 + DeletedAt gorm.DeletedAt +} diff --git a/store/model/chat_model.go b/store/model/chat_model.go new file mode 100644 index 0000000..134655f --- /dev/null +++ b/store/model/chat_model.go @@ -0,0 +1,16 @@ +package model + +type ChatModel struct { + BaseModel + Platform string + Name string + Value string // API Key 的值 + SortNum int + Enabled bool + Power int // 每次对话消耗算力 + Open bool // 是否开放模型给所有人使用 + MaxTokens int // 最大响应长度 + MaxContext int // 最大上下文长度 + Temperature float32 // 模型温度 + KeyId int // 绑定 API KEY ID +} diff --git a/store/model/chat_role.go b/store/model/chat_role.go new file mode 100644 index 0000000..50e438b --- /dev/null +++ b/store/model/chat_role.go @@ -0,0 +1,13 @@ +package model + +type ChatRole struct { + BaseModel + Key string `gorm:"column:marker;unique"` // 角色唯一标识 + Name string // 角色名称 + Context string `gorm:"column:context_json"` // 角色语料信息 json + HelloMsg string // 打招呼的消息 + Icon string // 角色聊天图标 + Enable bool // 是否启用被启用 + SortNum int //排序数字 + ModelId int // 绑定模型ID,绑定模型ID的角色只能用指定的模型来问答 +} diff --git a/store/model/config.go b/store/model/config.go new file mode 100644 index 0000000..3b43cbe --- /dev/null +++ b/store/model/config.go @@ -0,0 +1,7 @@ +package model + +type Config struct { + Id uint `gorm:"primarykey;column:id"` + Key string `gorm:"column:marker;unique"` + Config string `gorm:"column:config_json"` +} diff --git a/store/model/dalle_job.go b/store/model/dalle_job.go new file mode 100644 index 0000000..de7a13a --- /dev/null +++ b/store/model/dalle_job.go @@ -0,0 +1,16 @@ +package model + +import "time" + +type DallJob struct { + Id uint `gorm:"primarykey;column:id"` + UserId uint + Prompt string + ImgURL string + OrgURL string + Publish bool + Power int + Progress int + ErrMsg string + CreatedAt time.Time +} diff --git a/store/model/file.go b/store/model/file.go new file mode 100644 index 0000000..56fe424 --- /dev/null +++ b/store/model/file.go @@ -0,0 +1,14 @@ +package model + +import "time" + +type File struct { + Id uint `gorm:"primarykey;column:id"` + UserId int + Name string + ObjKey string + URL string + Ext string + Size int64 + CreatedAt time.Time +} diff --git a/store/model/function.go b/store/model/function.go new file mode 100644 index 0000000..2a8b80f --- /dev/null +++ b/store/model/function.go @@ -0,0 +1,12 @@ +package model + +type Function struct { + Id uint `gorm:"primarykey;column:id"` + Name string + Label string + Description string + Parameters string + Action string + Token string + Enabled bool +} diff --git a/store/model/invite_code.go b/store/model/invite_code.go new file mode 100644 index 0000000..588904d --- /dev/null +++ b/store/model/invite_code.go @@ -0,0 +1,12 @@ +package model + +import "time" + +type InviteCode struct { + Id uint `gorm:"primarykey;column:id"` + UserId uint + Code string + Hits int // 点击次数 + RegNum int // 注册人数 + CreatedAt time.Time +} diff --git a/store/model/invite_log.go b/store/model/invite_log.go new file mode 100644 index 0000000..22052b2 --- /dev/null +++ b/store/model/invite_log.go @@ -0,0 +1,15 @@ +package model + +import ( + "time" +) + +type InviteLog struct { + Id uint `gorm:"primarykey;column:id"` + InviterId uint + UserId uint + Username string + InviteCode string + Remark string + CreatedAt time.Time +} diff --git a/store/model/menu.go b/store/model/menu.go new file mode 100644 index 0000000..e215e20 --- /dev/null +++ b/store/model/menu.go @@ -0,0 +1,11 @@ +package model + +// Menu 系统菜单 +type Menu struct { + Id uint `gorm:"primarykey;column:id"` + Name string // 菜单名称 + Icon string // 菜单图标 + URL string // 菜单跳转地址 + SortNum int // 排序 + Enabled bool // 启用状态 +} diff --git a/store/model/mj_job.go b/store/model/mj_job.go new file mode 100644 index 0000000..b4e03a6 --- /dev/null +++ b/store/model/mj_job.go @@ -0,0 +1,27 @@ +package model + +import "time" + +type MidJourneyJob struct { + Id uint `gorm:"primarykey;column:id"` + Type string + UserId int + TaskId string + ChannelId string + MessageId string + ReferenceId string + ImgURL string + OrgURL string // 原图地址 + Hash string // message hash + Progress int + Prompt string + UseProxy bool // 是否使用反代加载图片 + Publish bool //是否发布图片到画廊 + ErrMsg string // 报错信息 + Power int // 消耗算力 + CreatedAt time.Time +} + +func (MidJourneyJob) TableName() string { + return "chatgpt_mj_jobs" +} diff --git a/store/model/order.go b/store/model/order.go new file mode 100644 index 0000000..a1c6929 --- /dev/null +++ b/store/model/order.go @@ -0,0 +1,23 @@ +package model + +import ( + "geekai/core/types" + "gorm.io/gorm" +) + +// Order 充值订单 +type Order struct { + BaseModel + UserId uint + ProductId uint + Username string + OrderNo string + TradeNo string + Subject string + Amount float64 + Status types.OrderStatus + Remark string + PayTime int64 + PayWay string // 支付方式 + DeletedAt gorm.DeletedAt +} diff --git a/store/model/power_log.go b/store/model/power_log.go new file mode 100644 index 0000000..fd6c322 --- /dev/null +++ b/store/model/power_log.go @@ -0,0 +1,20 @@ +package model + +import ( + "geekai/core/types" + "time" +) + +// PowerLog 算力消费日志 +type PowerLog struct { + Id uint `gorm:"primarykey;column:id"` + UserId uint + Username string + Type types.PowerType + Amount int + Balance int + Model string // 模型 + Remark string // 备注 + Mark types.PowerMark // 资金类型 + CreatedAt time.Time +} diff --git a/store/model/product.go b/store/model/product.go new file mode 100644 index 0000000..66e35d1 --- /dev/null +++ b/store/model/product.go @@ -0,0 +1,14 @@ +package model + +// Product 充值产品 +type Product struct { + BaseModel + Name string + Price float64 + Discount float64 + Days int + Power int + Enabled bool + Sales int + SortNum int +} diff --git a/store/model/reward.go b/store/model/reward.go new file mode 100644 index 0000000..43b9c8c --- /dev/null +++ b/store/model/reward.go @@ -0,0 +1,13 @@ +package model + +// 用户打赏 + +type Reward struct { + BaseModel + UserId uint // 用户 ID + TxId string // 交易ID + Amount float64 // 打赏金额 + Remark string // 打赏备注 + Status bool // 核销状态 + Exchange string // 众筹兑换详情,JSON +} diff --git a/store/model/sd_job.go b/store/model/sd_job.go new file mode 100644 index 0000000..8542c30 --- /dev/null +++ b/store/model/sd_job.go @@ -0,0 +1,22 @@ +package model + +import "time" + +type SdJob struct { + Id uint `gorm:"primarykey;column:id"` + Type string + UserId int + TaskId string + ImgURL string + Progress int + Prompt string + Params string + Publish bool //是否发布图片到画廊 + ErrMsg string // 报错信息 + Power int // 消耗算力 + CreatedAt time.Time +} + +func (SdJob) TableName() string { + return "chatgpt_sd_jobs" +} diff --git a/store/model/user.go b/store/model/user.go new file mode 100644 index 0000000..41d0990 --- /dev/null +++ b/store/model/user.go @@ -0,0 +1,19 @@ +package model + +type User struct { + BaseModel + Username string + Nickname string + Password string + Avatar string + Salt string // 密码盐 + Power int // 剩余算力 + ChatConfig string `gorm:"column:chat_config_json"` // 聊天配置 json + ChatRoles string `gorm:"column:chat_roles_json"` // 聊天角色 + ChatModels string `gorm:"column:chat_models_json"` // AI 模型,不同的用户拥有不同的聊天模型 + ExpiredTime int64 // 账户到期时间 + Status bool `gorm:"default:true"` // 当前状态 + LastLoginAt int64 // 最后登录时间 + LastLoginIp string // 最后登录 IP + Vip bool // 是否 VIP 会员 +} diff --git a/store/model/user_login_log.go b/store/model/user_login_log.go new file mode 100644 index 0000000..87596d5 --- /dev/null +++ b/store/model/user_login_log.go @@ -0,0 +1,9 @@ +package model + +type UserLoginLog struct { + BaseModel + UserId uint + Username string + LoginIp string + LoginAddress string +} diff --git a/store/mysql.go b/store/mysql.go new file mode 100644 index 0000000..70aba96 --- /dev/null +++ b/store/mysql.go @@ -0,0 +1,44 @@ +package store + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core/types" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "time" +) + +func NewGormConfig() *gorm.Config { + return &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + NamingStrategy: schema.NamingStrategy{ + TablePrefix: "chatgpt_", // 设置表前缀 + SingularTable: false, // 使用单数表名形式 + }, + } +} + +func NewMysql(config *gorm.Config, appConfig *types.AppConfig) (*gorm.DB, error) { + db, err := gorm.Open(mysql.Open(appConfig.MysqlDns), config) + if err != nil { + return nil, err + } + + sqlDB, err := db.DB() + if err != nil { + return nil, err + } + sqlDB.SetMaxIdleConns(32) + sqlDB.SetMaxOpenConns(512) + sqlDB.SetConnMaxLifetime(time.Hour) + + return db, nil +} diff --git a/store/redis.go b/store/redis.go new file mode 100644 index 0000000..f2e4814 --- /dev/null +++ b/store/redis.go @@ -0,0 +1,27 @@ +package store + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "context" + "geekai/core/types" + "github.com/go-redis/redis/v8" +) + +func NewRedisClient(config *types.AppConfig) (*redis.Client, error) { + client := redis.NewClient(&redis.Options{ + Addr: config.Redis.Url(), + Password: config.Redis.Password, + DB: config.Redis.DB, + }) + _, err := client.Ping(context.Background()).Result() + if err != nil { + return nil, err + } + return client, nil +} diff --git a/store/redis_queue.go b/store/redis_queue.go new file mode 100644 index 0000000..3251eb5 --- /dev/null +++ b/store/redis_queue.go @@ -0,0 +1,48 @@ +package store + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "context" + "geekai/utils" + "github.com/go-redis/redis/v8" +) + +type RedisQueue struct { + name string + client *redis.Client + ctx context.Context +} + +func NewRedisQueue(name string, client *redis.Client) *RedisQueue { + return &RedisQueue{name: name, client: client, ctx: context.Background()} +} + +func (q *RedisQueue) RPush(value interface{}) { + q.client.RPush(q.ctx, q.name, utils.JsonEncode(value)) +} + +func (q *RedisQueue) LPush(value interface{}) { + q.client.LPush(q.ctx, q.name, utils.JsonEncode(value)) +} + +func (q *RedisQueue) LPop(value interface{}) error { + result, err := q.client.BLPop(q.ctx, 0, q.name).Result() + if err != nil { + return err + } + return utils.JsonDecode(result[1], value) +} + +func (q *RedisQueue) RPop(value interface{}) error { + result, err := q.client.BRPop(q.ctx, 0, q.name).Result() + if err != nil { + return err + } + return utils.JsonDecode(result[1], value) +} diff --git a/store/vo/admin_user.go b/store/vo/admin_user.go new file mode 100644 index 0000000..24403be --- /dev/null +++ b/store/vo/admin_user.go @@ -0,0 +1,10 @@ +package vo + +type AdminUser struct { + BaseVo + Username string `json:"username"` + Status bool `json:"status"` // 当前状态 + LastLoginAt int64 `json:"last_login_at"` // 最后登录时间 + LastLoginIp string `json:"last_login_ip"` // 最后登录 IP + RoleIds interface{} `json:"role_ids"` //角色ids +} diff --git a/store/vo/api_key.go b/store/vo/api_key.go new file mode 100644 index 0000000..7321b13 --- /dev/null +++ b/store/vo/api_key.go @@ -0,0 +1,14 @@ +package vo + +// ApiKey OpenAI API 模型 +type ApiKey struct { + BaseVo + Platform string `json:"platform"` + Name string `json:"name"` + Type string `json:"type"` + Value string `json:"value"` // API Key 的值 + ApiURL string `json:"api_url"` + Enabled bool `json:"enabled"` + ProxyURL string `json:"proxy_url"` + LastUsedAt int64 `json:"last_used_at"` // 最后使用时间 +} diff --git a/store/vo/base.go b/store/vo/base.go new file mode 100644 index 0000000..1b467c3 --- /dev/null +++ b/store/vo/base.go @@ -0,0 +1,7 @@ +package vo + +type BaseVo struct { + Id uint `json:"id"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} diff --git a/store/vo/chat_history.go b/store/vo/chat_history.go new file mode 100644 index 0000000..3f534f3 --- /dev/null +++ b/store/vo/chat_history.go @@ -0,0 +1,14 @@ +package vo + +type HistoryMessage struct { + BaseVo + ChatId string `json:"chat_id"` + UserId uint `json:"user_id"` + RoleId uint `json:"role_id"` + Model string `json:"model"` + Type string `json:"type"` + Icon string `json:"icon"` + Tokens int `json:"tokens"` + Content string `json:"content"` + UseContext bool `json:"use_context"` +} diff --git a/store/vo/chat_item.go b/store/vo/chat_item.go new file mode 100644 index 0000000..4ac66fc --- /dev/null +++ b/store/vo/chat_item.go @@ -0,0 +1,13 @@ +package vo + +type ChatItem struct { + BaseVo + UserId uint `json:"user_id"` + Icon string `json:"icon"` + RoleId uint `json:"role_id"` + RoleName string `json:"role_name"` + ChatId string `json:"chat_id"` + ModelId uint `json:"model_id"` + Model string `json:"model"` + Title string `json:"title"` +} diff --git a/store/vo/chat_model.go b/store/vo/chat_model.go new file mode 100644 index 0000000..4fb2105 --- /dev/null +++ b/store/vo/chat_model.go @@ -0,0 +1,17 @@ +package vo + +type ChatModel struct { + BaseVo + Platform string `json:"platform"` + Name string `json:"name"` + Value string `json:"value"` + Enabled bool `json:"enabled"` + SortNum int `json:"sort_num"` + Power int `json:"power"` + Open bool `json:"open"` + MaxTokens int `json:"max_tokens"` // 最大响应长度 + MaxContext int `json:"max_context"` // 最大上下文长度 + Temperature float32 `json:"temperature"` // 模型温度 + KeyId int `json:"key_id"` + KeyName string `json:"key_name"` +} diff --git a/store/vo/chat_role.go b/store/vo/chat_role.go new file mode 100644 index 0000000..4bd530f --- /dev/null +++ b/store/vo/chat_role.go @@ -0,0 +1,16 @@ +package vo + +import "geekai/core/types" + +type ChatRole struct { + BaseVo + Key string `json:"key"` // 角色唯一标识 + Name string `json:"name"` // 角色名称 + Context []types.Message `json:"context"` // 角色语料信息 + HelloMsg string `json:"hello_msg"` // 打招呼的消息 + Icon string `json:"icon"` // 角色聊天图标 + Enable bool `json:"enable"` // 是否启用被启用 + SortNum int `json:"sort"` // 排序 + ModelId int `json:"model_id"` // 绑定模型 ID + ModelName string `json:"model_name"` // 模型名称 +} diff --git a/store/vo/config.go b/store/vo/config.go new file mode 100644 index 0000000..a0e7907 --- /dev/null +++ b/store/vo/config.go @@ -0,0 +1,9 @@ +package vo + +import "geekai/core/types" + +type Config struct { + Id uint `json:"id"` + Key string `json:"key"` + SystemConfig types.SystemConfig `json:"system_config"` +} diff --git a/store/vo/dalle_job.go b/store/vo/dalle_job.go new file mode 100644 index 0000000..28a6906 --- /dev/null +++ b/store/vo/dalle_job.go @@ -0,0 +1,14 @@ +package vo + +type DallJob struct { + Id uint `json:"id"` + UserId int `json:"user_id"` + Prompt string `json:"prompt"` + ImgURL string `json:"img_url"` + OrgURL string `json:"org_url"` + Publish bool `json:"publish"` + Power int `json:"power"` + Progress int `json:"progress"` + ErrMsg string `json:"err_msg"` + CreatedAt int64 `json:"created_at"` +} diff --git a/store/vo/file.go b/store/vo/file.go new file mode 100644 index 0000000..c5e83dc --- /dev/null +++ b/store/vo/file.go @@ -0,0 +1,12 @@ +package vo + +type File struct { + Id uint `json:"id"` + UserId uint `json:"user_id"` + Name string `json:"name"` + ObjKey string `json:"obj_key"` + URL string `json:"url"` + Ext string `json:"ext"` + Size int64 `json:"size"` + CreatedAt int64 `json:"created_at"` +} diff --git a/store/vo/function.go b/store/vo/function.go new file mode 100644 index 0000000..eb4323e --- /dev/null +++ b/store/vo/function.go @@ -0,0 +1,23 @@ +package vo + +type Parameters struct { + Type string `json:"type"` + Required []string `json:"required,omitempty"` + Properties map[string]Property `json:"properties"` +} + +type Property struct { + Type string `json:"type"` + Description string `json:"description"` +} + +type Function struct { + Id uint `json:"id"` + Name string `json:"name"` + Label string `json:"label"` + Description string `json:"description"` + Parameters Parameters `json:"parameters"` + Action string `json:"action"` + Token string `json:"token"` + Enabled bool `json:"enabled"` +} diff --git a/store/vo/invite_code.go b/store/vo/invite_code.go new file mode 100644 index 0000000..122bdc0 --- /dev/null +++ b/store/vo/invite_code.go @@ -0,0 +1,10 @@ +package vo + +type InviteCode struct { + Id uint `json:"id"` + UserId uint `json:"user_id"` + Code string `json:"code"` + Hits int `json:"hits"` + RegNum int `json:"reg_num"` + CreatedAt int64 `json:"created_at"` +} diff --git a/store/vo/invite_log.go b/store/vo/invite_log.go new file mode 100644 index 0000000..3be80c3 --- /dev/null +++ b/store/vo/invite_log.go @@ -0,0 +1,11 @@ +package vo + +type InviteLog struct { + Id uint `json:"id"` + InviterId uint `json:"inviter_id"` + UserId uint `json:"user_id"` + Username string `json:"username"` + InviteCode string `json:"invite_code"` + Remark string `json:"remark"` + CreatedAt int64 `json:"created_at"` +} diff --git a/store/vo/menu.go b/store/vo/menu.go new file mode 100644 index 0000000..c9975a4 --- /dev/null +++ b/store/vo/menu.go @@ -0,0 +1,11 @@ +package vo + +// Menu 系统菜单 +type Menu struct { + Id uint `json:"id"` + Name string `json:"name"` + Icon string `json:"icon"` + URL string `json:"url"` + SortNum int `json:"sort_num"` + Enabled bool `json:"enabled"` +} diff --git a/store/vo/mj_job.go b/store/vo/mj_job.go new file mode 100644 index 0000000..59ec11c --- /dev/null +++ b/store/vo/mj_job.go @@ -0,0 +1,23 @@ +package vo + +import "time" + +type MidJourneyJob struct { + Id uint `json:"id"` + Type string `json:"type"` + UserId int `json:"user_id"` + ChannelId string `json:"channel_id"` + TaskId string `json:"task_id"` + MessageId string `json:"message_id"` + ReferenceId string `json:"reference_id"` + ImgURL string `json:"img_url"` + OrgURL string `json:"org_url"` + Hash string `json:"hash"` + Progress int `json:"progress"` + Prompt string `json:"prompt"` + UseProxy bool `json:"use_proxy"` + Publish bool `json:"publish"` + ErrMsg string `json:"err_msg"` + Power int `json:"power"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/store/vo/order.go b/store/vo/order.go new file mode 100644 index 0000000..c076d00 --- /dev/null +++ b/store/vo/order.go @@ -0,0 +1,20 @@ +package vo + +import ( + "geekai/core/types" +) + +type Order struct { + BaseVo + UserId uint `json:"user_id"` + ProductId uint `json:"product_id"` + Username string `json:"username"` + OrderNo string `json:"order_no"` + TradeNo string `json:"trade_no"` + Subject string `json:"subject"` + Amount float64 `json:"amount"` + Status types.OrderStatus `json:"status"` + PayTime int64 `json:"pay_time"` + PayWay string `json:"pay_way"` + Remark types.OrderRemark `json:"remark"` +} diff --git a/store/vo/page.go b/store/vo/page.go new file mode 100644 index 0000000..b47d49e --- /dev/null +++ b/store/vo/page.go @@ -0,0 +1,22 @@ +package vo + +import "math" + +type Page struct { + Items interface{} `json:"items"` + Page int `json:"page"` + PageSize int `json:"page_size"` + Total int64 `json:"total"` + TotalPage int `json:"total_page"` +} + +func NewPage(total int64, page int, pageSize int, items interface{}) Page { + totalPage := math.Ceil(float64(total) / float64(pageSize)) + return Page{ + Items: items, + Page: page, + PageSize: pageSize, + Total: total, + TotalPage: int(totalPage), + } +} diff --git a/store/vo/power_log.go b/store/vo/power_log.go new file mode 100644 index 0000000..ae19a9d --- /dev/null +++ b/store/vo/power_log.go @@ -0,0 +1,17 @@ +package vo + +import "geekai/core/types" + +type PowerLog struct { + Id uint `json:"id"` + UserId uint `json:"user_id"` + Username string `json:"username"` + Type types.PowerType `json:"type"` + TypeStr string `json:"type_str"` + Amount int `json:"amount"` + Mark types.PowerMark `json:"mark"` + Balance int `json:"balance"` + Model string `json:"model"` + Remark string `json:"remark"` + CreatedAt int64 `json:"created_at"` +} diff --git a/store/vo/product.go b/store/vo/product.go new file mode 100644 index 0000000..0cc5e19 --- /dev/null +++ b/store/vo/product.go @@ -0,0 +1,13 @@ +package vo + +type Product struct { + BaseVo + Name string `json:"name"` + Price float64 `json:"price"` + Discount float64 `json:"discount"` + Days int `json:"days"` + Power int `json:"power"` + Enabled bool `json:"enabled"` + Sales int `json:"sales"` + SortNum int `json:"sort_num"` +} diff --git a/store/vo/reward.go b/store/vo/reward.go new file mode 100644 index 0000000..b3c5ac1 --- /dev/null +++ b/store/vo/reward.go @@ -0,0 +1,16 @@ +package vo + +type Reward struct { + BaseVo + UserId uint `json:"user_id"` // 用户 ID + Username string `json:"username"` + TxId string `json:"tx_id"` // 交易ID + Amount float64 `json:"amount"` // 打赏金额 + Remark string `json:"remark"` // 打赏备注 + Status bool `json:"status"` // 核销状态 + Exchange RewardExchange `json:"exchange"` +} + +type RewardExchange struct { + Power int `json:"power"` +} diff --git a/store/vo/sd_job.go b/store/vo/sd_job.go new file mode 100644 index 0000000..8d52150 --- /dev/null +++ b/store/vo/sd_job.go @@ -0,0 +1,21 @@ +package vo + +import ( + "geekai/core/types" + "time" +) + +type SdJob struct { + Id uint `json:"id"` + Type string `json:"type"` + UserId int `json:"user_id"` + TaskId string `json:"task_id"` + ImgURL string `json:"img_url"` + Params types.SdTaskParams `json:"params"` + Progress int `json:"progress"` + Prompt string `json:"prompt"` + Publish bool `json:"publish"` + ErrMsg string `json:"err_msg"` + Power int `json:"power"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/store/vo/user.go b/store/vo/user.go new file mode 100644 index 0000000..560f57d --- /dev/null +++ b/store/vo/user.go @@ -0,0 +1,17 @@ +package vo + +type User struct { + BaseVo + Username string `json:"username"` + Nickname string `json:"nickname"` + Avatar string `json:"avatar"` + Salt string `json:"salt"` // 密码盐 + Power int `json:"power"` // 剩余算力 + ChatRoles []string `json:"chat_roles"` // 聊天角色集合 + ChatModels []int `json:"chat_models"` // AI模型集合 + ExpiredTime int64 `json:"expired_time"` // 账户到期时间 + Status bool `json:"status"` // 当前状态 + LastLoginAt int64 `json:"last_login_at"` // 最后登录时间 + LastLoginIp string `json:"last_login_ip"` // 最后登录 IP + Vip bool `json:"vip"` +} diff --git a/store/vo/user_login_log.go b/store/vo/user_login_log.go new file mode 100644 index 0000000..b4a094a --- /dev/null +++ b/store/vo/user_login_log.go @@ -0,0 +1,9 @@ +package vo + +type UserLoginLog struct { + BaseVo + UserId uint `json:"user_id"` + Username string `json:"username"` + LoginIp string `json:"login_ip"` + LoginAddress string `json:"login_address"` +} diff --git a/test/test.go b/test/test.go new file mode 100644 index 0000000..0a48ec9 --- /dev/null +++ b/test/test.go @@ -0,0 +1,12 @@ +package main + +import ( + "fmt" + "net/url" +) + +func main() { + text := "https://nk.img.r9it.com/chatgpt-plus/1712709360012445.png" + parse, _ := url.Parse(text) + fmt.Println(fmt.Sprintf("%s://%s", parse.Scheme, parse.Host)) +} diff --git a/utils/common.go b/utils/common.go new file mode 100644 index 0000000..142256d --- /dev/null +++ b/utils/common.go @@ -0,0 +1,212 @@ +package utils + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/lionsoul2014/ip2region/binding/golang/xdb" + "github.com/nfnt/resize" + "github.com/skip2/go-qrcode" + "image" + "image/color" + "image/draw" + "image/jpeg" + "io" + "reflect" + "strconv" + "strings" +) + +// CopyObject 拷贝对象 +func CopyObject(src interface{}, dst interface{}) error { + + srcType := reflect.TypeOf(src) + srcValue := reflect.ValueOf(src) + dstValue := reflect.ValueOf(dst).Elem() + reflect.TypeOf(dst) + for i := 0; i < srcType.NumField(); i++ { + field := srcType.Field(i) + value := dstValue.FieldByName(field.Name) + if !value.IsValid() { + continue + } + // 数据类型相同,直接赋值 + v := srcValue.FieldByName(field.Name) + if value.Type() == field.Type { + value.Set(v) + } else { + // src data type is string,dst data type is slice, map, struct + // use json decode the data + if field.Type.Kind() == reflect.String && (value.Type().Kind() == reflect.Struct || + value.Type().Kind() == reflect.Map || + value.Type().Kind() == reflect.Slice) { + pType := reflect.New(value.Type()) + v2 := pType.Interface() + err := json.Unmarshal([]byte(v.String()), &v2) + if err == nil && v2 != nil { + value.Set(reflect.ValueOf(v2).Elem()) + } + // map, struct, slice to string + } else if (field.Type.Kind() == reflect.Struct || + field.Type.Kind() == reflect.Map || + field.Type.Kind() == reflect.Slice) && value.Type().Kind() == reflect.String { + ba, err := json.Marshal(v.Interface()) + if err == nil { + val := string(ba) + if strings.Contains(val, "{") { + value.Set(reflect.ValueOf(string(ba))) + } else { + value.Set(reflect.ValueOf("")) + } + } + } else if field.Type.Kind() != value.Type().Kind() { // 不同类型的字段过滤掉 + continue + } else { // 简单数据类型的强制类型转换 + switch value.Kind() { + case reflect.Int: + case reflect.Int8: + case reflect.Int16: + case reflect.Int32: + case reflect.Int64: + value.SetInt(v.Int()) + break + case reflect.Float32: + case reflect.Float64: + value.SetFloat(v.Float()) + break + case reflect.Bool: + value.SetBool(v.Bool()) + break + } + } + + } + } + + return nil +} + +func Ip2Region(searcher *xdb.Searcher, ip string) string { + str, err := searcher.SearchByStr(ip) + if err != nil { + return "" + } + arr := strings.Split(str, "|") + if len(arr) < 3 { + return arr[0] + } + return fmt.Sprintf("%s-%s-%s", arr[0], arr[2], arr[3]) +} + +func IsEmptyValue(obj interface{}) bool { + if obj == nil { + return true + } + + v := reflect.ValueOf(obj) + switch v.Kind() { + case reflect.Ptr, reflect.Interface: + return v.IsNil() + case reflect.Array, reflect.Slice, reflect.Map, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Complex64, reflect.Complex128: + return v.Complex() == 0 + default: + return reflect.DeepEqual(obj, reflect.Zero(reflect.TypeOf(obj)).Interface()) + } +} + +func BoolValue(str string) bool { + value, err := strconv.ParseBool(str) + if err != nil { + return false + } + return value +} + +func FloatValue(str string) float64 { + value, err := strconv.ParseFloat(str, 64) + if err != nil { + return 0 + } + return value +} + +func IntValue(str string, defaultValue int) int { + value, err := strconv.Atoi(str) + if err != nil { + return defaultValue + } + return value +} + +func ForceCovert(src any, dst interface{}) error { + b, err := json.Marshal(src) + if err != nil { + return err + } + err = json.Unmarshal(b, dst) + if err != nil { + return err + } + return nil +} + +func GenQrcode(text string, size int, logo io.Reader) ([]byte, error) { + qr, err := qrcode.New(text, qrcode.Medium) + if err != nil { + return nil, err + } + + qr.BackgroundColor = color.White + qr.ForegroundColor = color.Black + if logo == nil { + return qr.PNG(size) + } + + // 生成带Logo的二维码图像 + logoImage, _, err := image.Decode(logo) + if err != nil { + return nil, err + } + + // 缩放 Logo + scaledLogo := resize.Resize(uint(size/9), uint(size/9), logoImage, resize.Lanczos3) + // 将Logo叠加到二维码图像上 + qrWithLogo := overlayLogo(qr.Image(size), scaledLogo) + + // 将带Logo的二维码图像以JPEG格式编码为图片数据 + var buf bytes.Buffer + err = jpeg.Encode(&buf, qrWithLogo, nil) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// 叠加Logo到图片上 +func overlayLogo(qrImage, logoImage image.Image) image.Image { + offsetX := (qrImage.Bounds().Dx() - logoImage.Bounds().Dx()) / 2 + offsetY := (qrImage.Bounds().Dy() - logoImage.Bounds().Dy()) / 2 + + combinedImage := image.NewRGBA(qrImage.Bounds()) + draw.Draw(combinedImage, qrImage.Bounds(), qrImage, image.Point{}, draw.Over) + draw.Draw(combinedImage, logoImage.Bounds().Add(image.Pt(offsetX, offsetY)), logoImage, image.Point{}, draw.Over) + + return combinedImage +} diff --git a/utils/crypto.go b/utils/crypto.go new file mode 100644 index 0000000..30fa07d --- /dev/null +++ b/utils/crypto.go @@ -0,0 +1,98 @@ +package utils + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" +) + +// AesEncrypt 加密 +func AesEncrypt(keyStr string, data []byte) (string, error) { + //创建加密实例 + key := []byte(keyStr) + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + blockSize := block.BlockSize() + encryptBytes := pkcs7Padding(data, blockSize) + result := make([]byte, len(encryptBytes)) + //使用cbc加密模式 + blockMode := cipher.NewCBCEncrypter(block, key[:blockSize]) + //执行加密 + blockMode.CryptBlocks(result, encryptBytes) + return base64.StdEncoding.EncodeToString(result), nil +} + +// AesDecrypt 解密 +func AesDecrypt(keyStr string, dataStr string) ([]byte, error) { + //创建实例 + key := []byte(keyStr) + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + data, err := base64.StdEncoding.DecodeString(dataStr) + if err != nil { + return nil, err + } + + blockSize := block.BlockSize() + blockMode := cipher.NewCBCDecrypter(block, key[:blockSize]) + result := make([]byte, len(data)) + //执行解密 + blockMode.CryptBlocks(result, data) + //去除填充 + result, err = pkcs7UnPadding(result) + if err != nil { + return nil, err + } + return result, nil +} + +func pkcs7Padding(data []byte, blockSize int) []byte { + padding := blockSize - len(data)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(data, padText...) +} + +func pkcs7UnPadding(data []byte) ([]byte, error) { + length := len(data) + if length == 0 { + return nil, errors.New("empty encrypt data") + } + unPadding := int(data[length-1]) + return data[:(length - unPadding)], nil +} + +func Sha256(data string) string { + hash := sha256.New() + _, err := io.WriteString(hash, data) + if err != nil { + return "" + } + + hashValue := hash.Sum(nil) + return fmt.Sprintf("%x", hashValue) +} + +func Md5(data string) string { + md5bs := md5.Sum([]byte(data)) + return hex.EncodeToString(md5bs[:]) +} diff --git a/utils/net.go b/utils/net.go new file mode 100644 index 0000000..5f02922 --- /dev/null +++ b/utils/net.go @@ -0,0 +1,70 @@ +package utils + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "encoding/json" + "geekai/core/types" + logger2 "geekai/logger" + "io" + "net/http" + "net/url" +) + +var logger = logger2.GetLogger() + +// ReplyChunkMessage 回复客户片段端消息 +func ReplyChunkMessage(client *types.WsClient, message interface{}) { + msg, err := json.Marshal(message) + if err != nil { + logger.Errorf("Error for decoding json data: %v", err.Error()) + return + } + err = client.Send(msg) + if err != nil { + logger.Errorf("Error for reply message: %v", err.Error()) + } +} + +// ReplyMessage 回复客户端一条完整的消息 +func ReplyMessage(ws *types.WsClient, message interface{}) { + ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: message}) + ReplyChunkMessage(ws, types.WsMessage{Type: types.WsEnd}) +} + +func DownloadImage(imageURL string, proxy string) ([]byte, error) { + var client *http.Client + if proxy == "" { + client = http.DefaultClient + } else { + proxyURL, _ := url.Parse(proxy) + client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + } + } + request, err := http.NewRequest("GET", imageURL, nil) + if err != nil { + return nil, err + } + + resp, err := client.Do(request) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + imageBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return imageBytes, nil +} diff --git a/utils/openai.go b/utils/openai.go new file mode 100644 index 0000000..86a976a --- /dev/null +++ b/utils/openai.go @@ -0,0 +1,93 @@ +package utils + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "geekai/core/types" + "geekai/store/model" + "github.com/imroc/req/v3" + "github.com/pkoukk/tiktoken-go" + "gorm.io/gorm" + "time" +) + +func CalcTokens(text string, model string) (int, error) { + encoding, ok := tiktoken.MODEL_TO_ENCODING[model] + if !ok { + encoding = "cl100k_base" + } + tke, err := tiktoken.GetEncoding(encoding) + if err != nil { + return 0, fmt.Errorf("getEncoding: %v", err) + } + + token := tke.Encode(text, nil, nil) + return len(token), nil +} + +type apiRes struct { + Model string `json:"model"` + Choices []struct { + Index int `json:"index"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` +} + +type apiErrRes struct { + Error struct { + Code interface{} `json:"code"` + Message string `json:"message"` + Param interface{} `json:"param"` + Type string `json:"type"` + } `json:"error"` +} + +func OpenAIRequest(db *gorm.DB, prompt string) (string, error) { + var apiKey model.ApiKey + res := db.Where("platform = ?", types.OpenAI.Value).Where("type", "chat").Where("enabled = ?", true).First(&apiKey) + if res.Error != nil { + return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error) + } + + messages := make([]interface{}, 1) + messages[0] = types.Message{ + Role: "user", + Content: prompt, + } + + var response apiRes + var errRes apiErrRes + client := req.C() + if len(apiKey.ProxyURL) > 5 { + client.SetProxyURL(apiKey.ApiURL) + } + r, err := client.R().SetHeader("Content-Type", "application/json"). + SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(types.ApiRequest{ + Model: "gpt-3.5-turbo-0125", + Temperature: 0.9, + MaxTokens: 1024, + Stream: false, + Messages: messages, + }). + SetErrorResult(&errRes). + SetSuccessResult(&response).Post(apiKey.ApiURL) + if err != nil || r.IsErrorState() { + return "", fmt.Errorf("error with http request: %v%v%s", err, r.Err, errRes.Error.Message) + } + + // 更新 API KEY 的最后使用时间 + db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) + + return response.Choices[0].Message.Content, nil +} diff --git a/utils/resp/response.go b/utils/resp/response.go new file mode 100644 index 0000000..3d21124 --- /dev/null +++ b/utils/resp/response.go @@ -0,0 +1,51 @@ +package resp + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "geekai/core/types" + "github.com/gin-gonic/gin" + "net/http" +) + +func SUCCESS(c *gin.Context, values ...interface{}) { + if values != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: values[0]}) + } else { + c.JSON(http.StatusOK, types.BizVo{Code: types.Success}) + } + +} + +func ERROR(c *gin.Context, messages ...string) { + if messages != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: messages[0]}) + } else { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed}) + } +} + +func HACKER(c *gin.Context) { + c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Hacker attempt!!!"}) +} + +func NotAuth(c *gin.Context, messages ...string) { + if messages != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: messages[0]}) + } else { + c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: "Not Authorized"}) + } +} + +func NotPermission(c *gin.Context, messages ...string) { + if messages != nil { + c.JSON(http.StatusOK, types.BizVo{Code: types.NotPermission, Message: messages[0]}) + } else { + c.JSON(http.StatusOK, types.BizVo{Code: types.NotPermission, Message: "Not Permission"}) + } +} diff --git a/utils/strings.go b/utils/strings.go new file mode 100644 index 0000000..ff5c28e --- /dev/null +++ b/utils/strings.go @@ -0,0 +1,126 @@ +package utils + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "encoding/json" + "fmt" + "math/rand" + "strings" + "time" + "unicode" + + "golang.org/x/crypto/sha3" +) + +// RandString generate rand string with specified length +func RandString(length int) string { + str := "0123456789abcdefghijklmnopqrstuvwxyz" + data := []byte(str) + var result []byte + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < length; i++ { + result = append(result, data[r.Intn(len(data))]) + } + return string(result) +} + +func RandomNumber(bit int) int { + min := intPow(10, bit-1) + max := intPow(10, bit) - 1 + + rand.Seed(time.Now().UnixNano()) + return rand.Intn(max-min+1) + min +} + +func intPow(x, y int) int { + result := 1 + for i := 0; i < y; i++ { + result *= x + } + return result +} + +func ContainsStr(slice []string, item string) bool { + for _, e := range slice { + if e == item { + return true + } + } + return false +} + +// Stamp2str 时间戳转字符串 +func Stamp2str(timestamp int64) string { + if timestamp == 0 { + return "" + } + return time.Unix(timestamp, 0).Format("2006-01-02 15:04:05") +} + +// Str2stamp 字符串转时间戳 +func Str2stamp(str string) int64 { + if len(str) == 0 { + return 0 + } + + layout := "2006-01-02 15:04:05" + t, err := time.ParseInLocation(layout, str, time.Local) + if err != nil { + return 0 + } + return t.Unix() +} + +func GenPassword(pass string, salt string) string { + data := []byte(pass + salt) + hash := sha3.Sum256(data) + return fmt.Sprintf("%x", hash) +} + +func JsonEncode(value interface{}) string { + bytes, err := json.Marshal(value) + if err != nil { + return "" + } + return string(bytes) +} + +func JsonDecode(src string, dest interface{}) error { + return json.Unmarshal([]byte(src), dest) +} + +func InterfaceToString(value interface{}) string { + if str, ok := value.(string); ok { + return str + } + return JsonEncode(value) +} + +// CutWords 截取前 N 个单词 +func CutWords(str string, num int) string { + // 按空格分割字符串为单词切片 + words := strings.Fields(str) + + // 如果单词数量超过指定数量,则裁剪单词;否则保持原样 + if len(words) > num { + return strings.Join(words[:num], " ") + " ..." + } else { + return str + } +} + +// HasChinese 判断文本是否含有中文 +func HasChinese(text string) bool { + for _, char := range text { + if unicode.Is(unicode.Scripts["Han"], char) { + return true + } + } + return false +} diff --git a/utils/upload.go b/utils/upload.go new file mode 100644 index 0000000..5c764d8 --- /dev/null +++ b/utils/upload.go @@ -0,0 +1,101 @@ +package utils + +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +// * Copyright 2023 The Geek-AI Authors. All rights reserved. +// * Use of this source code is governed by a Apache-2.0 license +// * that can be found in the LICENSE file. +// * @Author yangjian102621@163.com +// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import ( + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "regexp" + "strings" + "time" +) + +// GenUploadPath 生成上传文件路径 +func GenUploadPath(basePath, filename string, isImg bool) (string, error) { + now := time.Now() + dir := fmt.Sprintf("%s/%d/%d", basePath, now.Year(), now.Month()) + _, err := os.Stat(dir) + if err != nil { + err = os.MkdirAll(dir, 0755) + if err != nil { + return "", fmt.Errorf("error with create upload dir:%v", err) + } + } + var fileExt string + if isImg { + fileExt = GetImgExt(filename) + } else { + fileExt = filepath.Ext(filename) + } + return fmt.Sprintf("%s/%d%s", dir, now.UnixMicro(), fileExt), nil +} + +// GenUploadUrl 生成上传文件 URL +func GenUploadUrl(basePath, baseUrl string, filePath string) string { + return strings.Replace(filePath, basePath, baseUrl, 1) +} + +func DownloadFile(fileURL string, filepath string, proxy string) error { + var client *http.Client + if proxy == "" { + client = http.DefaultClient + } else { + proxyURL, _ := url.Parse(proxy) + client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + } + } + req, err := http.NewRequest("GET", fileURL, nil) + if err != nil { + return err + } + + resp, err := client.Do(req) + if err != nil { + return err + } + + file, err := os.Create(filepath) + if err != nil { + return err + } + defer file.Close() + + _, err = io.Copy(file, resp.Body) + if err != nil { + return err + } + + return nil +} + +func GetImgExt(filename string) string { + ext := filepath.Ext(filename) + if ext == "" { + return ".png" + } + return ext +} + +func ExtractImgURL(text string) []string { + re := regexp.MustCompile(`(http[s]?:\/\/.*?\.(?:png|jpg|jpeg|gif))`) + matches := re.FindAllStringSubmatch(text, 10) + urls := make([]string, 0) + if len(matches) > 0 { + for _, m := range matches { + urls = append(urls, m[1]) + } + } + return urls +}