diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 8e718916..6630ba02 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -153,6 +153,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { rpmCache := repository.NewRPMCache(redisClient) groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService) + riskRepository := service.NewRiskRepository(db, settingRepository, redisClient) + riskService := service.NewRiskService(riskRepository, settingRepository, redisClient) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) dataManagementService := service.NewDataManagementService() @@ -179,9 +181,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, riskService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, riskService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) @@ -217,7 +219,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler) + riskHandler := admin.NewRiskHandler(riskService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, riskHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) diff --git a/backend/internal/handler/admin/risk_handler.go b/backend/internal/handler/admin/risk_handler.go new file mode 100644 index 00000000..b1829558 --- /dev/null +++ b/backend/internal/handler/admin/risk_handler.go @@ -0,0 +1,114 @@ +package admin + +import ( + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type RiskHandler struct { + service *service.RiskService +} + +func NewRiskHandler(svc *service.RiskService) *RiskHandler { + return &RiskHandler{service: svc} +} + +func (h *RiskHandler) GetSummary(c *gin.Context) { + summary, err := h.service.GetSummary(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, summary) +} + +func (h *RiskHandler) ListAccounts(c *gin.Context) { + filter := service.RiskAccountFilter{ + Level: c.Query("risk_level"), + Platform: c.Query("platform"), + } + if p := c.Query("page"); p != "" { + if v, err := strconv.Atoi(p); err == nil { + filter.Page = v + } + } + if l := c.Query("limit"); l != "" { + if v, err := strconv.Atoi(l); err == nil { + filter.PageSize = v + } + } + + list, err := h.service.ListAccounts(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, list) +} + +func (h *RiskHandler) GetAccountDetail(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil || id <= 0 { + response.ErrorFrom(c, service.ErrRiskAccountNotFound) + return + } + detail, err := h.service.GetAccountDetail(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, detail) +} + +type overrideRiskLevelRequest struct { + Level string `json:"level" binding:"required"` + Reason string `json:"reason" binding:"required"` +} + +func (h *RiskHandler) OverrideRiskLevel(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil || id <= 0 { + response.ErrorFrom(c, service.ErrRiskAccountNotFound) + return + } + + var req overrideRiskLevelRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.ErrorFrom(c, err) + return + } + + if err := h.service.OverrideRiskLevel(c.Request.Context(), id, req.Level, req.Reason); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, nil) +} + +func (h *RiskHandler) GetSettings(c *gin.Context) { + settings, err := h.service.GetSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, settings) +} + +func (h *RiskHandler) UpdateSettings(c *gin.Context) { + var req service.RiskSettings + if err := c.ShouldBindJSON(&req); err != nil { + response.ErrorFrom(c, err) + return + } + updated, err := h.service.UpdateSettings(c.Request.Context(), &req) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, updated) +} diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b2467eac..0353a4a5 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -30,6 +30,7 @@ type AdminHandlers struct { TLSFingerprintProfile *admin.TLSFingerprintProfileHandler APIKey *admin.AdminAPIKeyHandler ScheduledTest *admin.ScheduledTestHandler + Risk *admin.RiskHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 02ddd030..d1afc5da 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -33,6 +33,7 @@ func ProvideAdminHandlers( tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler, apiKeyHandler *admin.AdminAPIKeyHandler, scheduledTestHandler *admin.ScheduledTestHandler, + riskHandler *admin.RiskHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -59,6 +60,7 @@ func ProvideAdminHandlers( TLSFingerprintProfile: tlsFingerprintProfileHandler, APIKey: apiKeyHandler, ScheduledTest: scheduledTestHandler, + Risk: riskHandler, } } @@ -150,6 +152,7 @@ var ProviderSet = wire.NewSet( admin.NewTLSFingerprintProfileHandler, admin.NewAdminAPIKeyHandler, admin.NewScheduledTestHandler, + admin.NewRiskHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 1b45e507..75f2384f 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -262,8 +262,41 @@ func hasMCPTools(tools []ClaudeTool) bool { return false } -// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令 +// claudeCodeSignatures Claude Code / Anthropic 特征字符串,命中任意一个即视为需要过滤的 CLI 默认 prompt +var claudeCodeSignatures = []string{ + "You are Claude Code, Anthropic's official CLI", + "You are Claude Code,", + "Anthropic's official CLI", + "x-anthropic-billing-header", + "cc_entrypoint=cli", +} + +// filterClaudeCodePrompt 过滤 Claude Code 默认 system prompt,防止 Anthropic 特征暴露给上游 +// 策略:检测到特征字符串后,尝试提取用户自定义指令部分("Instructions from:" 之后),否则返回空 +func filterClaudeCodePrompt(text string) (string, bool) { + matched := false + for _, sig := range claudeCodeSignatures { + if strings.Contains(text, sig) { + matched = true + break + } + } + if !matched { + return text, false + } + // 尝试保留用户自定义指令 + if idx := strings.Index(text, "Instructions from:"); idx >= 0 { + return text[idx:], true + } + return "", true +} + +// filterOpenCodePrompt 过滤 OpenCode / Claude Code 默认提示词,只保留用户自定义指令 func filterOpenCodePrompt(text string) string { + // 优先检测 Claude Code 特征 + if filtered, matched := filterClaudeCodePrompt(text); matched { + return filtered + } if !strings.Contains(text, "You are an interactive CLI tool") { return text } diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index 9e46295a..4d22db7d 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -2,7 +2,6 @@ package antigravity import ( "encoding/json" - "strings" "testing" "github.com/stretchr/testify/require" @@ -353,7 +352,7 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) { } } -func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t *testing.T) { +func TestTransformClaudeToGeminiWithOptions_FiltersBillingHeaderSystemBlock(t *testing.T) { tests := []struct { name string system json.RawMessage @@ -388,15 +387,11 @@ func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t require.NoError(t, json.Unmarshal(body, &req)) require.NotNil(t, req.Request.SystemInstruction) - found := false + // Claude Code / Anthropic 特征字符串不应透传给上游 for _, part := range req.Request.SystemInstruction.Parts { - if strings.Contains(part.Text, "x-anthropic-billing-header keep") { - found = true - break - } + require.NotContains(t, part.Text, "x-anthropic-billing-header", + "Claude Code 特征字符串不应透传给 Antigravity 上游") } - - require.True(t, found, "转换后的 systemInstruction 应保留 x-anthropic-billing-header 内容") }) } } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 1513411d..a832e318 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -87,6 +87,9 @@ func RegisterAdminRoutes( // 定时测试计划 registerScheduledTestRoutes(admin, h) + + // 风控中心 + registerRiskRoutes(admin, h) } } @@ -566,3 +569,15 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete) } } + +func registerRiskRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + risk := admin.Group("/risk") + { + risk.GET("/summary", h.Admin.Risk.GetSummary) + risk.GET("/accounts", h.Admin.Risk.ListAccounts) + risk.GET("/accounts/:id", h.Admin.Risk.GetAccountDetail) + risk.PUT("/accounts/:id/override", h.Admin.Risk.OverrideRiskLevel) + risk.GET("/settings", h.Admin.Risk.GetSettings) + risk.PUT("/settings", h.Admin.Risk.UpdateSettings) + } +} diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index 1b360d93..c4e53ad0 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -11,7 +11,7 @@ import ( ) const ( - antigravityTokenRefreshSkew = 3 * time.Minute + antigravityTokenRefreshSkew = 5 * time.Minute antigravityTokenCacheSkew = 5 * time.Minute antigravityBackfillCooldown = 5 * time.Minute // antigravityRequestRefreshTimeout 请求路径上 token 刷新的最大等待时间。 diff --git a/backend/internal/service/antigravity_token_refresher.go b/backend/internal/service/antigravity_token_refresher.go index 7ce0ccf0..5ddcc5c7 100644 --- a/backend/internal/service/antigravity_token_refresher.go +++ b/backend/internal/service/antigravity_token_refresher.go @@ -36,7 +36,8 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool { } // NeedsRefresh 检查账户是否需要刷新 -// Antigravity 使用固定的15分钟刷新窗口,忽略全局配置 +// Deprecated: Antigravity 已改为请求路径按需刷新,不再注册后台定时刷新器。 +// 此方法仅保留以满足 TokenRefresher 接口,不会被 TokenRefreshService 调用。 func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool { if !r.CanRefresh(account) { return false diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 4ae5a469..f92a4ac3 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -235,6 +235,9 @@ const ( // SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录 SettingKeyBackendModeEnabled = "backend_mode_enabled" + + // SettingKeyRiskSettings 风控系统配置 (JSON) + SettingKeyRiskSettings = "risk_settings" ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index ee447d2f..77d35d9f 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -565,6 +565,7 @@ type GatewayService struct { debugClaudeMimic atomic.Bool debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set tlsFPProfileService *TLSFingerprintProfileService + riskService *RiskService } // NewGatewayService creates a new GatewayService @@ -592,6 +593,7 @@ func NewGatewayService( digestStore *DigestSessionStore, settingService *SettingService, tlsFPProfileService *TLSFingerprintProfileService, + riskService *RiskService, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) @@ -624,6 +626,7 @@ func NewGatewayService( modelsListCacheTTL: modelsListTTL, responseHeaderFilter: compileResponseHeaderFilter(cfg), tlsFPProfileService: tlsFPProfileService, + riskService: riskService, } svc.userGroupRateResolver = newUserGroupRateResolver( userGroupRateRepo, @@ -7683,6 +7686,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.riskService.CollectBehaviorAsync(ctx, account, usageLog) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -7706,6 +7710,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu return billingErr } writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") + s.riskService.CollectBehaviorAsync(ctx, account, usageLog) return nil } @@ -7866,6 +7871,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.riskService.CollectBehaviorAsync(ctx, account, usageLog) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -7889,6 +7895,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * return billingErr } writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") + s.riskService.CollectBehaviorAsync(ctx, account, usageLog) return nil } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index a72a86ac..967ec43a 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -337,6 +337,7 @@ type OpenAIGatewayService struct { openaiWSRetryMetrics openAIWSRetryMetrics responseHeaderFilter *responseheaders.CompiledHeaderFilter codexSnapshotThrottle *accountWriteThrottle + riskService *RiskService } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -357,6 +358,7 @@ func NewOpenAIGatewayService( httpUpstream HTTPUpstream, deferredService *DeferredService, openAITokenProvider *OpenAITokenProvider, + riskService *RiskService, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ accountRepo: accountRepo, @@ -386,6 +388,7 @@ func NewOpenAIGatewayService( openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), + riskService: riskService, } svc.logOpenAIWSModeBootstrap() return svc @@ -4227,6 +4230,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.riskService.CollectBehaviorAsync(ctx, account, usageLog) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -4250,6 +4254,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec return billingErr } writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") + s.riskService.CollectBehaviorAsync(ctx, account, usageLog) return nil } diff --git a/backend/internal/service/risk_models.go b/backend/internal/service/risk_models.go new file mode 100644 index 00000000..d5ffe968 --- /dev/null +++ b/backend/internal/service/risk_models.go @@ -0,0 +1,121 @@ +package service + +import ( + "encoding/json" + "time" +) + +const ( + RiskLevelLow = "LOW" + RiskLevelMedium = "MEDIUM" + RiskLevelHigh = "HIGH" +) + +const ( + RiskPhaseOff = "off" + RiskPhaseObserve = "observe" + RiskPhaseEnforce = "enforce" +) + +const ( + riskSettingsCacheKey = "settings:risk:v1" +) + +type RiskSettings struct { + MediumThreshold float64 `json:"medium_threshold"` + HighThreshold float64 `json:"high_threshold"` + Phase string `json:"phase"` +} + +func DefaultRiskSettings() *RiskSettings { + return &RiskSettings{ + MediumThreshold: 0.45, + HighThreshold: 0.75, + Phase: RiskPhaseObserve, + } +} + +type RiskBehaviorHourDelta struct { + APICallCount int64 + StreamCount int64 + TotalInputTokens int64 + TotalOutputTokens int64 + TotalDurationMs int64 + P50DurationMs *int +} + +type RiskSummary struct { + TotalAccounts int64 `json:"total_accounts"` + LowCount int64 `json:"low_count"` + MediumCount int64 `json:"medium_count"` + HighCount int64 `json:"high_count"` + AverageScore float64 `json:"average_score"` + LastScoredAt *time.Time `json:"last_scored_at,omitempty"` + Settings *RiskSettings `json:"settings"` +} + +type RiskAccountFilter struct { + Page int + PageSize int + Level string + Platform string +} + +type RiskAccountListItem struct { + AccountID int64 `json:"account_id"` + AccountName string `json:"account_name"` + Platform string `json:"platform"` + RiskScore float64 `json:"risk_score"` + RiskLevel string `json:"risk_level"` + RiskReasons json.RawMessage `json:"risk_reasons"` + FeatureVector json.RawMessage `json:"feature_vector"` + IdleOverride bool `json:"idle_override"` + ScoredAt time.Time `json:"scored_at"` + LastHourCalls int64 `json:"last_hour_calls"` + LastHourTokens int64 `json:"last_hour_tokens"` +} + +type RiskAccountList struct { + Items []*RiskAccountListItem `json:"items"` + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +type RiskBehaviorHour struct { + HourBucket time.Time `json:"hour_bucket"` + APICallCount int64 `json:"api_call_count"` + StreamCount int64 `json:"stream_count"` + TotalInputTokens int64 `json:"total_input_tokens"` + TotalOutputTokens int64 `json:"total_output_tokens"` + TotalDurationMs int64 `json:"total_duration_ms"` + P50DurationMs *int `json:"p50_duration_ms,omitempty"` +} + +type RiskAccountDetail struct { + AccountID int64 `json:"account_id"` + AccountName string `json:"account_name"` + Platform string `json:"platform"` + RiskScore float64 `json:"risk_score"` + RiskLevel string `json:"risk_level"` + RiskReasons json.RawMessage `json:"risk_reasons"` + FeatureVector json.RawMessage `json:"feature_vector"` + IdleOverride bool `json:"idle_override"` + ScoredAt time.Time `json:"scored_at"` + ModelVersion int `json:"model_version"` + HourlyBehavior []RiskBehaviorHour `json:"hourly_behavior"` +} + +type RiskScoreRecord struct { + ID int64 `json:"id"` + AccountID int64 `json:"account_id"` + RiskScore float64 `json:"risk_score"` + RiskLevel string `json:"risk_level"` + RiskReasons json.RawMessage `json:"risk_reasons"` + FeatureVector json.RawMessage `json:"feature_vector"` + ScoredAt time.Time `json:"scored_at"` + ModelVersion int `json:"model_version"` + IdleOverride bool `json:"idle_override"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/backend/internal/service/risk_repository.go b/backend/internal/service/risk_repository.go new file mode 100644 index 00000000..efbc5afe --- /dev/null +++ b/backend/internal/service/risk_repository.go @@ -0,0 +1,494 @@ +package service + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/redis/go-redis/v9" +) + +var ( + ErrRiskAccountNotFound = infraerrors.NotFound("RISK_ACCOUNT_NOT_FOUND", "risk account not found") + ErrRiskLevelInvalid = infraerrors.BadRequest("RISK_LEVEL_INVALID", "risk level must be LOW, MEDIUM, or HIGH") +) + +type RiskRepository interface { + UpsertBehaviorHour(ctx context.Context, accountID int64, hour time.Time, delta RiskBehaviorHourDelta) error + GetRiskSummary(ctx context.Context) (*RiskSummary, error) + ListRiskAccounts(ctx context.Context, filter RiskAccountFilter) (*RiskAccountList, error) + GetRiskAccountDetail(ctx context.Context, accountID int64) (*RiskAccountDetail, error) + OverrideRiskLevel(ctx context.Context, accountID int64, level, reason string) error + GetOrCreateRiskScore(ctx context.Context, accountID int64) (*RiskScoreRecord, error) +} + +type pgRiskRepository struct { + db *sql.DB + settingRepo SettingRepository + redis *redis.Client +} + +func NewRiskRepository(db *sql.DB, settingRepo SettingRepository, redisClient *redis.Client) RiskRepository { + return &pgRiskRepository{ + db: db, + settingRepo: settingRepo, + redis: redisClient, + } +} + +const riskUpsertBehaviorHourSQL = ` +INSERT INTO account_behavior_hourly ( + account_id, hour_bucket, api_call_count, stream_count, + total_input_tokens, total_output_tokens, total_duration_ms, p50_duration_ms, + created_at, updated_at +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, NOW(), NOW() +) +ON CONFLICT (account_id, hour_bucket) DO UPDATE SET + api_call_count = account_behavior_hourly.api_call_count + EXCLUDED.api_call_count, + stream_count = account_behavior_hourly.stream_count + EXCLUDED.stream_count, + total_input_tokens = account_behavior_hourly.total_input_tokens + EXCLUDED.total_input_tokens, + total_output_tokens = account_behavior_hourly.total_output_tokens + EXCLUDED.total_output_tokens, + total_duration_ms = account_behavior_hourly.total_duration_ms + EXCLUDED.total_duration_ms, + p50_duration_ms = COALESCE(EXCLUDED.p50_duration_ms, account_behavior_hourly.p50_duration_ms), + updated_at = NOW() +` + +const riskSummarySQL = ` +SELECT + COUNT(rs.account_id)::bigint AS total_accounts, + COUNT(*) FILTER (WHERE rs.risk_level = 'LOW')::bigint AS low_count, + COUNT(*) FILTER (WHERE rs.risk_level = 'MEDIUM')::bigint AS medium_count, + COUNT(*) FILTER (WHERE rs.risk_level = 'HIGH')::bigint AS high_count, + COALESCE(AVG(rs.risk_score), 0)::double precision AS average_score, + MAX(rs.scored_at) AS last_scored_at +FROM account_risk_scores rs +JOIN accounts a ON a.id = rs.account_id +WHERE a.deleted_at IS NULL + AND a.type IN ('oauth', 'setup_token') +` + +const riskListSQL = ` +WITH current_hour AS ( + SELECT account_id, api_call_count, + total_input_tokens + total_output_tokens AS total_tokens + FROM account_behavior_hourly + WHERE hour_bucket = date_trunc('hour', NOW() AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' +) +SELECT + COUNT(*) OVER()::bigint AS total_count, + rs.account_id, + a.name, + a.platform, + rs.risk_score, + rs.risk_level, + rs.risk_reasons, + rs.feature_vector, + rs.idle_override, + rs.scored_at, + COALESCE(ch.api_call_count, 0)::bigint AS last_hour_calls, + COALESCE(ch.total_tokens, 0)::bigint AS last_hour_tokens +FROM account_risk_scores rs +JOIN accounts a ON a.id = rs.account_id +LEFT JOIN current_hour ch ON ch.account_id = rs.account_id +WHERE a.deleted_at IS NULL + AND a.type IN ('oauth', 'setup_token') + AND ($1 = '' OR rs.risk_level = $1) + AND ($2 = '' OR a.platform = $2) +ORDER BY rs.risk_score DESC, rs.scored_at DESC, rs.account_id DESC +LIMIT $3 OFFSET $4 +` + +const riskDetailSQL = ` +SELECT + rs.account_id, + a.name, + a.platform, + rs.risk_score, + rs.risk_level, + rs.risk_reasons, + rs.feature_vector, + rs.idle_override, + rs.scored_at, + rs.model_version, + bh.hour_bucket, + bh.api_call_count, + bh.stream_count, + bh.total_input_tokens, + bh.total_output_tokens, + bh.total_duration_ms, + bh.p50_duration_ms +FROM account_risk_scores rs +JOIN accounts a ON a.id = rs.account_id +LEFT JOIN account_behavior_hourly bh + ON bh.account_id = rs.account_id + AND bh.hour_bucket >= (date_trunc('hour', NOW() AT TIME ZONE 'UTC') AT TIME ZONE 'UTC') - INTERVAL '24 hours' +WHERE a.deleted_at IS NULL + AND a.type IN ('oauth', 'setup_token') + AND rs.account_id = $1 +ORDER BY bh.hour_bucket DESC NULLS LAST +` + +const riskOverrideSQL = ` +UPDATE account_risk_scores +SET + risk_level = $2, + idle_override = TRUE, + risk_reasons = COALESCE(risk_reasons, '{}'::jsonb) || jsonb_build_object( + 'manual_override', jsonb_build_object('level', $2, 'reason', $3, 'at', NOW()) + ), + updated_at = NOW() +WHERE account_id = $1 +RETURNING updated_at +` + +const riskScoreRefreshSQL = ` +WITH valid_account AS ( + SELECT id FROM accounts + WHERE id = $1 AND deleted_at IS NULL AND type IN ('oauth', 'setup_token') +), +behavior AS ( + SELECT + va.id AS account_id, + COALESCE(SUM(abh.api_call_count), 0)::double precision AS total_calls_24h, + COALESCE(AVG(abh.api_call_count), 0)::double precision AS calls_per_hour_24h, + COALESCE(SUM(abh.stream_count), 0)::double precision AS stream_calls_24h, + COALESCE(SUM(abh.total_input_tokens), 0)::double precision AS total_input_tokens_24h, + COALESCE( + percentile_cont(0.50) WITHIN GROUP (ORDER BY abh.p50_duration_ms) + FILTER (WHERE abh.p50_duration_ms IS NOT NULL), + 0 + )::double precision AS duration_p50_ms, + COALESCE(stddev_pop(abh.api_call_count), 0)::double precision AS hourly_entropy + FROM valid_account va + LEFT JOIN account_behavior_hourly abh + ON abh.account_id = va.id + AND abh.hour_bucket >= (date_trunc('hour', NOW() AT TIME ZONE 'UTC') AT TIME ZONE 'UTC') - INTERVAL '24 hours' + GROUP BY va.id +), +features AS ( + SELECT + b.account_id, + b.calls_per_hour_24h, + COALESCE(b.stream_calls_24h / NULLIF(b.total_calls_24h, 0), 0) AS stream_ratio_24h, + COALESCE(b.total_input_tokens_24h / NULLIF(b.total_calls_24h, 0), 0) AS token_per_request_avg, + b.duration_p50_ms, + b.hourly_entropy, + b.total_calls_24h + FROM behavior b +), +scored AS ( + SELECT + f.account_id, + LEAST(1.0, + (0.25 * LEAST(f.calls_per_hour_24h / 50.0, 1.0)) + + (0.20 * LEAST(GREATEST(1.0 - f.stream_ratio_24h, 0.0), 1.0)) + + (0.15 * LEAST(f.token_per_request_avg / 50000.0, 1.0)) + + (0.20 * LEAST(f.duration_p50_ms / 30000.0, 1.0)) + + (0.20 * LEAST(f.total_calls_24h / 500.0, 1.0)) + ) AS risk_score, + jsonb_build_object( + 'calls_per_hour_24h', ROUND(f.calls_per_hour_24h::numeric, 6), + 'stream_ratio_24h', ROUND(f.stream_ratio_24h::numeric, 6), + 'token_per_request_avg', ROUND(f.token_per_request_avg::numeric, 6), + 'duration_p50_ms', ROUND(f.duration_p50_ms::numeric, 6), + 'hourly_entropy', ROUND(f.hourly_entropy::numeric, 6), + 'total_calls_24h', ROUND(f.total_calls_24h::numeric, 6) + ) AS feature_vector, + jsonb_build_object( + 'auto', to_jsonb(array_remove(ARRAY[ + CASE WHEN f.calls_per_hour_24h >= 50 THEN 'high_calls_per_hour' END, + CASE WHEN f.stream_ratio_24h <= 0.20 THEN 'low_stream_ratio' END, + CASE WHEN f.token_per_request_avg >= 50000 THEN 'high_token_per_request' END, + CASE WHEN f.duration_p50_ms >= 30000 THEN 'high_latency_p50' END, + CASE WHEN f.total_calls_24h >= 500 THEN 'high_volume_24h' END + ], NULL)) + ) AS risk_reasons + FROM features f +) +INSERT INTO account_risk_scores ( + account_id, risk_score, risk_level, risk_reasons, feature_vector, + scored_at, model_version, idle_override, created_at, updated_at +) +SELECT + s.account_id, + s.risk_score, + CASE + WHEN s.risk_score >= $3 THEN 'HIGH' + WHEN s.risk_score >= $2 THEN 'MEDIUM' + ELSE 'LOW' + END, + s.risk_reasons, + s.feature_vector, + NOW(), 1, FALSE, NOW(), NOW() +FROM scored s +ON CONFLICT (account_id) DO UPDATE SET + risk_score = EXCLUDED.risk_score, + risk_level = CASE + WHEN account_risk_scores.idle_override THEN account_risk_scores.risk_level + ELSE EXCLUDED.risk_level + END, + risk_reasons = CASE + WHEN account_risk_scores.idle_override THEN + COALESCE(EXCLUDED.risk_reasons, '{}'::jsonb) || + CASE WHEN account_risk_scores.risk_reasons ? 'manual_override' + THEN jsonb_build_object('manual_override', account_risk_scores.risk_reasons -> 'manual_override') + ELSE '{}'::jsonb + END + ELSE EXCLUDED.risk_reasons + END, + feature_vector = EXCLUDED.feature_vector, + scored_at = EXCLUDED.scored_at, + model_version = EXCLUDED.model_version, + idle_override = account_risk_scores.idle_override, + updated_at = NOW() +RETURNING id, account_id, risk_score, risk_level, risk_reasons, feature_vector, + scored_at, model_version, idle_override, created_at, updated_at +` + +func (r *pgRiskRepository) UpsertBehaviorHour(ctx context.Context, accountID int64, hour time.Time, delta RiskBehaviorHourDelta) error { + hour = hour.UTC().Truncate(time.Hour) + _, err := r.db.ExecContext( + ctx, riskUpsertBehaviorHourSQL, + accountID, hour, + delta.APICallCount, delta.StreamCount, + delta.TotalInputTokens, delta.TotalOutputTokens, + delta.TotalDurationMs, riskNullableInt(delta.P50DurationMs), + ) + return err +} + +func (r *pgRiskRepository) GetRiskSummary(ctx context.Context) (*RiskSummary, error) { + var ( + totalAccounts int64 + lowCount int64 + mediumCount int64 + highCount int64 + averageScore float64 + lastScoredAt sql.NullTime + ) + if err := r.db.QueryRowContext(ctx, riskSummarySQL).Scan( + &totalAccounts, &lowCount, &mediumCount, &highCount, &averageScore, &lastScoredAt, + ); err != nil { + return nil, err + } + + settings, err := loadRiskSettings(ctx, r.settingRepo, r.redis) + if err != nil { + settings = DefaultRiskSettings() + } + + summary := &RiskSummary{ + TotalAccounts: totalAccounts, + LowCount: lowCount, + MediumCount: mediumCount, + HighCount: highCount, + AverageScore: averageScore, + Settings: settings, + } + if lastScoredAt.Valid { + ts := lastScoredAt.Time + summary.LastScoredAt = &ts + } + return summary, nil +} + +func (r *pgRiskRepository) ListRiskAccounts(ctx context.Context, filter RiskAccountFilter) (*RiskAccountList, error) { + level := strings.ToUpper(strings.TrimSpace(filter.Level)) + platform := strings.TrimSpace(filter.Platform) + limit := filter.PageSize + if limit <= 0 { + limit = 20 + } + page := filter.Page + if page <= 0 { + page = 1 + } + offset := (page - 1) * limit + + rows, err := r.db.QueryContext(ctx, riskListSQL, level, platform, limit, offset) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + result := &RiskAccountList{ + Items: make([]*RiskAccountListItem, 0, limit), + Page: page, + PageSize: limit, + } + for rows.Next() { + var ( + totalCount int64 + accountName string + platformName string + riskReasons []byte + featureVector []byte + item RiskAccountListItem + ) + if err := rows.Scan( + &totalCount, + &item.AccountID, &accountName, &platformName, + &item.RiskScore, &item.RiskLevel, + &riskReasons, &featureVector, + &item.IdleOverride, &item.ScoredAt, + &item.LastHourCalls, &item.LastHourTokens, + ); err != nil { + return nil, err + } + item.AccountName = accountName + item.Platform = platformName + item.RiskReasons = riskDecodeJSON(riskReasons) + item.FeatureVector = riskDecodeJSON(featureVector) + result.Total = totalCount + result.Items = append(result.Items, &item) + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + +func (r *pgRiskRepository) GetRiskAccountDetail(ctx context.Context, accountID int64) (*RiskAccountDetail, error) { + if accountID <= 0 { + return nil, ErrRiskAccountNotFound + } + if _, err := r.GetOrCreateRiskScore(ctx, accountID); err != nil { + return nil, err + } + + rows, err := r.db.QueryContext(ctx, riskDetailSQL, accountID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var detail *RiskAccountDetail + for rows.Next() { + var ( + accountName string + platformName string + riskReasons []byte + featureVector []byte + hourBucket sql.NullTime + apiCallCount sql.NullInt64 + streamCount sql.NullInt64 + totalInputTokens sql.NullInt64 + totalOutputTokens sql.NullInt64 + totalDurationMs sql.NullInt64 + p50DurationMs sql.NullInt64 + ) + if detail == nil { + detail = &RiskAccountDetail{HourlyBehavior: make([]RiskBehaviorHour, 0, 24)} + } + if err := rows.Scan( + &detail.AccountID, &accountName, &platformName, + &detail.RiskScore, &detail.RiskLevel, + &riskReasons, &featureVector, + &detail.IdleOverride, &detail.ScoredAt, &detail.ModelVersion, + &hourBucket, &apiCallCount, &streamCount, + &totalInputTokens, &totalOutputTokens, &totalDurationMs, &p50DurationMs, + ); err != nil { + return nil, err + } + detail.AccountName = accountName + detail.Platform = platformName + detail.RiskReasons = riskDecodeJSON(riskReasons) + detail.FeatureVector = riskDecodeJSON(featureVector) + if hourBucket.Valid { + var p50 *int + if p50DurationMs.Valid { + v := int(p50DurationMs.Int64) + p50 = &v + } + detail.HourlyBehavior = append(detail.HourlyBehavior, RiskBehaviorHour{ + HourBucket: hourBucket.Time, + APICallCount: apiCallCount.Int64, + StreamCount: streamCount.Int64, + TotalInputTokens: totalInputTokens.Int64, + TotalOutputTokens: totalOutputTokens.Int64, + TotalDurationMs: totalDurationMs.Int64, + P50DurationMs: p50, + }) + } + } + if err := rows.Err(); err != nil { + return nil, err + } + if detail == nil { + return nil, ErrRiskAccountNotFound + } + return detail, nil +} + +func (r *pgRiskRepository) OverrideRiskLevel(ctx context.Context, accountID int64, level, reason string) error { + level = strings.ToUpper(strings.TrimSpace(level)) + switch level { + case RiskLevelLow, RiskLevelMedium, RiskLevelHigh: + default: + return ErrRiskLevelInvalid + } + if _, err := r.GetOrCreateRiskScore(ctx, accountID); err != nil { + return err + } + var updatedAt time.Time + err := r.db.QueryRowContext(ctx, riskOverrideSQL, accountID, level, strings.TrimSpace(reason)).Scan(&updatedAt) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return ErrRiskAccountNotFound + } + return err + } + _ = updatedAt + return nil +} + +func (r *pgRiskRepository) GetOrCreateRiskScore(ctx context.Context, accountID int64) (*RiskScoreRecord, error) { + if accountID <= 0 { + return nil, ErrRiskAccountNotFound + } + settings, err := loadRiskSettings(ctx, r.settingRepo, r.redis) + if err != nil { + settings = DefaultRiskSettings() + } + + record := &RiskScoreRecord{} + var riskReasons, featureVector []byte + err = r.db.QueryRowContext( + ctx, riskScoreRefreshSQL, + accountID, settings.MediumThreshold, settings.HighThreshold, + ).Scan( + &record.ID, &record.AccountID, + &record.RiskScore, &record.RiskLevel, + &riskReasons, &featureVector, + &record.ScoredAt, &record.ModelVersion, &record.IdleOverride, + &record.CreatedAt, &record.UpdatedAt, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRiskAccountNotFound + } + return nil, err + } + record.RiskReasons = riskDecodeJSON(riskReasons) + record.FeatureVector = riskDecodeJSON(featureVector) + return record, nil +} + +func riskNullableInt(v *int) any { + if v == nil { + return nil + } + return *v +} + +func riskDecodeJSON(raw []byte) json.RawMessage { + if len(raw) == 0 { + return json.RawMessage(`{}`) + } + return json.RawMessage(raw) +} diff --git a/backend/internal/service/risk_service.go b/backend/internal/service/risk_service.go new file mode 100644 index 00000000..82037757 --- /dev/null +++ b/backend/internal/service/risk_service.go @@ -0,0 +1,247 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/redis/go-redis/v9" +) + +const riskSettingsCacheTTL = 5 * time.Minute + +var ( + ErrRiskOverrideReasonRequired = infraerrors.BadRequest("RISK_OVERRIDE_REASON_REQUIRED", "override reason is required") + ErrRiskSettingsInvalid = infraerrors.BadRequest("RISK_SETTINGS_INVALID", "risk settings are invalid") +) + +type RiskService struct { + repo RiskRepository + settingRepo SettingRepository + redis *redis.Client +} + +func NewRiskService(repo RiskRepository, settingRepo SettingRepository, redisClient *redis.Client) *RiskService { + return &RiskService{ + repo: repo, + settingRepo: settingRepo, + redis: redisClient, + } +} + +func (s *RiskService) CollectBehaviorAsync(ctx context.Context, account *Account, usageLog *UsageLog) { + if s == nil || s.repo == nil || account == nil || usageLog == nil { + return + } + if !account.IsOAuth() { + return + } + + go func() { + bg, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + createdAt := usageLog.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now() + } + + delta := RiskBehaviorHourDelta{ + APICallCount: 1, + StreamCount: riskBoolToInt64(usageLog.Stream), + TotalInputTokens: int64(usageLog.InputTokens), + TotalOutputTokens: int64(usageLog.OutputTokens), + TotalDurationMs: riskIntPtrToInt64(usageLog.DurationMs), + P50DurationMs: usageLog.DurationMs, + } + + if err := s.repo.UpsertBehaviorHour(bg, usageLog.AccountID, createdAt, delta); err != nil { + slog.Warn("risk behavior upsert failed", "account_id", usageLog.AccountID, "error", err) + return + } + + settings, err := loadRiskSettings(bg, s.settingRepo, s.redis) + if err != nil { + settings = DefaultRiskSettings() + } + if settings.Phase == RiskPhaseOff { + return + } + + if _, err := s.repo.GetOrCreateRiskScore(bg, usageLog.AccountID); err != nil { + slog.Warn("risk score refresh failed", "account_id", usageLog.AccountID, "error", err) + } + }() +} + +func (s *RiskService) GetSummary(ctx context.Context) (*RiskSummary, error) { + if s == nil || s.repo == nil { + return nil, fmt.Errorf("risk service not initialized") + } + return s.repo.GetRiskSummary(ctx) +} + +func (s *RiskService) ListAccounts(ctx context.Context, filter RiskAccountFilter) (*RiskAccountList, error) { + if s == nil || s.repo == nil { + return nil, fmt.Errorf("risk service not initialized") + } + if filter.Page <= 0 { + filter.Page = 1 + } + if filter.PageSize <= 0 { + filter.PageSize = 20 + } + if filter.PageSize > 200 { + filter.PageSize = 200 + } + filter.Level = strings.ToUpper(strings.TrimSpace(filter.Level)) + filter.Platform = strings.TrimSpace(filter.Platform) + return s.repo.ListRiskAccounts(ctx, filter) +} + +func (s *RiskService) GetAccountDetail(ctx context.Context, accountID int64) (*RiskAccountDetail, error) { + if s == nil || s.repo == nil { + return nil, fmt.Errorf("risk service not initialized") + } + if accountID <= 0 { + return nil, ErrRiskAccountNotFound + } + return s.repo.GetRiskAccountDetail(ctx, accountID) +} + +func (s *RiskService) OverrideRiskLevel(ctx context.Context, accountID int64, level, reason string) error { + if s == nil || s.repo == nil { + return fmt.Errorf("risk service not initialized") + } + if accountID <= 0 { + return ErrRiskAccountNotFound + } + level = strings.ToUpper(strings.TrimSpace(level)) + reason = strings.TrimSpace(reason) + if reason == "" { + return ErrRiskOverrideReasonRequired + } + switch level { + case RiskLevelLow, RiskLevelMedium, RiskLevelHigh: + default: + return ErrRiskLevelInvalid + } + return s.repo.OverrideRiskLevel(ctx, accountID, level, reason) +} + +func (s *RiskService) GetSettings(ctx context.Context) (*RiskSettings, error) { + if s == nil || s.settingRepo == nil { + return DefaultRiskSettings(), nil + } + return loadRiskSettings(ctx, s.settingRepo, s.redis) +} + +func (s *RiskService) UpdateSettings(ctx context.Context, settings *RiskSettings) (*RiskSettings, error) { + if s == nil || s.settingRepo == nil { + return nil, fmt.Errorf("risk service not initialized") + } + normalized, err := normalizeRiskSettings(settings) + if err != nil { + return nil, err + } + data, err := json.Marshal(normalized) + if err != nil { + return nil, err + } + if err := s.settingRepo.Set(ctx, SettingKeyRiskSettings, string(data)); err != nil { + return nil, err + } + if s.redis != nil { + _ = s.redis.Del(ctx, riskSettingsCacheKey).Err() + } + return normalized, nil +} + +func loadRiskSettings(ctx context.Context, settingRepo SettingRepository, redisClient *redis.Client) (*RiskSettings, error) { + if ctx == nil { + ctx = context.Background() + } + + if redisClient != nil { + if raw, err := redisClient.Get(ctx, riskSettingsCacheKey).Result(); err == nil && strings.TrimSpace(raw) != "" { + settings := DefaultRiskSettings() + if err := json.Unmarshal([]byte(raw), settings); err == nil { + if normalized, err := normalizeRiskSettings(settings); err == nil { + return normalized, nil + } + } + } + } + + settings := DefaultRiskSettings() + if settingRepo != nil { + if raw, err := settingRepo.GetValue(ctx, SettingKeyRiskSettings); err == nil && strings.TrimSpace(raw) != "" { + if unmarshalErr := json.Unmarshal([]byte(raw), settings); unmarshalErr != nil { + slog.Warn("risk settings json invalid; using defaults", "error", unmarshalErr) + settings = DefaultRiskSettings() + } + } + } + + normalized, err := normalizeRiskSettings(settings) + if err != nil { + normalized = DefaultRiskSettings() + } + + if redisClient != nil { + if data, marshalErr := json.Marshal(normalized); marshalErr == nil { + _ = redisClient.Set(ctx, riskSettingsCacheKey, string(data), riskSettingsCacheTTL).Err() + } + } + return normalized, nil +} + +func normalizeRiskSettings(settings *RiskSettings) (*RiskSettings, error) { + if settings == nil { + return DefaultRiskSettings(), nil + } + out := &RiskSettings{ + MediumThreshold: settings.MediumThreshold, + HighThreshold: settings.HighThreshold, + Phase: strings.ToLower(strings.TrimSpace(settings.Phase)), + } + if out.MediumThreshold == 0 && out.HighThreshold == 0 && out.Phase == "" { + return DefaultRiskSettings(), nil + } + if out.Phase == "" { + out.Phase = RiskPhaseObserve + } + if out.MediumThreshold < 0 || out.MediumThreshold > 1 { + return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("medium_threshold must be between 0 and 1")) + } + if out.HighThreshold < 0 || out.HighThreshold > 1 { + return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("high_threshold must be between 0 and 1")) + } + if out.MediumThreshold >= out.HighThreshold { + return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("medium_threshold must be less than high_threshold")) + } + switch out.Phase { + case RiskPhaseOff, RiskPhaseObserve, RiskPhaseEnforce: + default: + return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("phase must be one of: %s, %s, %s", RiskPhaseOff, RiskPhaseObserve, RiskPhaseEnforce)) + } + return out, nil +} + +func riskBoolToInt64(v bool) int64 { + if v { + return 1 + } + return 0 +} + +func riskIntPtrToInt64(v *int) int64 { + if v == nil { + return 0 + } + return int64(*v) +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 4fa2fe97..6166e1f8 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -63,14 +63,15 @@ func NewTokenRefreshService( claudeRefresher := NewClaudeTokenRefresher(oauthService) geminiRefresher := NewGeminiTokenRefresher(geminiOAuthService) - agRefresher := NewAntigravityTokenRefresher(antigravityOAuthService) + // Antigravity 使用请求路径按需刷新(GetAccessToken 内部处理),不注册后台定时刷新器。 + // 后台定时刷新会导致 idle 账号每天产生 ~48 次无效 OAuth 请求,触发风控封号。 + _ = antigravityOAuthService // 保留参数引用,避免编译错误 // 注册平台特定的刷新器(TokenRefresher 接口) s.refreshers = []TokenRefresher{ claudeRefresher, openAIRefresher, geminiRefresher, - agRefresher, } // 注册对应的 OAuthRefreshExecutor(带 CacheKey 方法) @@ -78,7 +79,6 @@ func NewTokenRefreshService( claudeRefresher, openAIRefresher, geminiRefresher, - agRefresher, } return s diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index d79a3531..d858c72f 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet( ProvideScheduledTestService, ProvideScheduledTestRunnerService, NewGroupCapacityService, + NewRiskRepository, + NewRiskService, ) diff --git a/backend/migrations/081_create_risk_tables.sql b/backend/migrations/081_create_risk_tables.sql new file mode 100644 index 00000000..283d6919 --- /dev/null +++ b/backend/migrations/081_create_risk_tables.sql @@ -0,0 +1,49 @@ +-- +migrate Up + +CREATE TABLE IF NOT EXISTS account_behavior_hourly ( + id BIGSERIAL PRIMARY KEY, + account_id BIGINT NOT NULL, + hour_bucket TIMESTAMPTZ NOT NULL, + api_call_count BIGINT NOT NULL DEFAULT 0, + stream_count BIGINT NOT NULL DEFAULT 0, + total_input_tokens BIGINT NOT NULL DEFAULT 0, + total_output_tokens BIGINT NOT NULL DEFAULT 0, + total_duration_ms BIGINT NOT NULL DEFAULT 0, + p50_duration_ms INT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT uq_account_behavior_hourly UNIQUE (account_id, hour_bucket) +); + +CREATE INDEX IF NOT EXISTS idx_account_behavior_hourly_account_id ON account_behavior_hourly (account_id); +CREATE INDEX IF NOT EXISTS idx_account_behavior_hourly_hour_bucket ON account_behavior_hourly (hour_bucket DESC); + +CREATE TABLE IF NOT EXISTS account_risk_scores ( + id BIGSERIAL PRIMARY KEY, + account_id BIGINT NOT NULL, + risk_score DOUBLE PRECISION NOT NULL DEFAULT 0, + risk_level VARCHAR(16) NOT NULL DEFAULT 'LOW', + risk_reasons JSONB NOT NULL DEFAULT '{}', + feature_vector JSONB NOT NULL DEFAULT '{}', + scored_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + model_version INT NOT NULL DEFAULT 1, + idle_override BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT uq_account_risk_scores_account_id UNIQUE (account_id) +); + +CREATE INDEX IF NOT EXISTS idx_account_risk_scores_risk_level ON account_risk_scores (risk_level); +CREATE INDEX IF NOT EXISTS idx_account_risk_scores_risk_score ON account_risk_scores (risk_score DESC); +CREATE INDEX IF NOT EXISTS idx_account_risk_scores_scored_at ON account_risk_scores (scored_at DESC); + +-- +migrate Down + +DROP INDEX IF EXISTS idx_account_risk_scores_scored_at; +DROP INDEX IF EXISTS idx_account_risk_scores_risk_score; +DROP INDEX IF EXISTS idx_account_risk_scores_risk_level; +DROP TABLE IF EXISTS account_risk_scores; + +DROP INDEX IF EXISTS idx_account_behavior_hourly_hour_bucket; +DROP INDEX IF EXISTS idx_account_behavior_hourly_account_id; +DROP TABLE IF EXISTS account_behavior_hourly; diff --git a/frontend/src/api/admin/risk.ts b/frontend/src/api/admin/risk.ts new file mode 100644 index 00000000..f34f2291 --- /dev/null +++ b/frontend/src/api/admin/risk.ts @@ -0,0 +1,99 @@ +import { apiClient } from '../client' + +export interface RiskSummary { + total_monitored: number + high_risk_count: number + medium_risk_count: number + low_risk_count: number + blocked_count: number + avg_score: number +} + +export interface RiskAccountListItem { + account_id: number + email: string + platform: string + risk_level: string + risk_score: number + scored_at: string + is_overridden: boolean +} + +export interface RiskAccountList { + items: RiskAccountListItem[] + total: number + page: number + page_size: number +} + +export interface RiskBehaviorHour { + hour_bucket: string + request_count: number + token_count: number + error_count: number +} + +export interface RiskAccountDetail { + account_id: number + email: string + platform: string + risk_level: string + risk_score: number + scored_at: string + is_overridden: boolean + override_reason: string + overridden_at: string + behavior_24h: RiskBehaviorHour[] +} + +export interface RiskSettings { + enabled: boolean + phase: string + medium_threshold: number + high_threshold: number +} + +export interface RiskAccountFilter { + page?: number + limit?: number + risk_level?: string + platform?: string +} + +export async function getRiskSummary(): Promise { + const res = await apiClient.get('/admin/risk/summary') + return res.data.data +} + +export async function listRiskAccounts(filter: RiskAccountFilter = {}): Promise { + const params: Record = {} + if (filter.page) params['page'] = String(filter.page) + if (filter.limit) params['limit'] = String(filter.limit) + if (filter.risk_level) params['risk_level'] = filter.risk_level + if (filter.platform) params['platform'] = filter.platform + const res = await apiClient.get('/admin/risk/accounts', { params }) + return res.data.data +} + +export async function getRiskAccountDetail(id: number): Promise { + const res = await apiClient.get(`/admin/risk/accounts/${id}`) + return res.data.data +} + +export async function overrideRiskLevel( + id: number, + level: string, + reason: string +): Promise { + await apiClient.put(`/admin/risk/accounts/${id}/override`, { level, reason }) +} + +export async function getRiskSettings(): Promise { + const res = await apiClient.get('/admin/risk/settings') + return res.data.data +} + +export async function updateRiskSettings(settings: RiskSettings): Promise { + const res = await apiClient.put('/admin/risk/settings', settings) + return res.data.data +} diff --git a/frontend/src/components/admin/risk/RiskAccountDrawer.vue b/frontend/src/components/admin/risk/RiskAccountDrawer.vue new file mode 100644 index 00000000..45d11d1b --- /dev/null +++ b/frontend/src/components/admin/risk/RiskAccountDrawer.vue @@ -0,0 +1,195 @@ + + + + + diff --git a/frontend/src/components/admin/risk/RiskDistributionChart.vue b/frontend/src/components/admin/risk/RiskDistributionChart.vue new file mode 100644 index 00000000..42181781 --- /dev/null +++ b/frontend/src/components/admin/risk/RiskDistributionChart.vue @@ -0,0 +1,64 @@ + + + diff --git a/frontend/src/components/admin/risk/RiskRadarChart.vue b/frontend/src/components/admin/risk/RiskRadarChart.vue new file mode 100644 index 00000000..fd68c665 --- /dev/null +++ b/frontend/src/components/admin/risk/RiskRadarChart.vue @@ -0,0 +1,90 @@ + + + diff --git a/frontend/src/components/admin/risk/RiskSummaryCards.vue b/frontend/src/components/admin/risk/RiskSummaryCards.vue new file mode 100644 index 00000000..7ead42d3 --- /dev/null +++ b/frontend/src/components/admin/risk/RiskSummaryCards.vue @@ -0,0 +1,36 @@ + + + diff --git a/frontend/src/components/admin/risk/RiskSystemStatusCard.vue b/frontend/src/components/admin/risk/RiskSystemStatusCard.vue new file mode 100644 index 00000000..67e452ec --- /dev/null +++ b/frontend/src/components/admin/risk/RiskSystemStatusCard.vue @@ -0,0 +1,101 @@ + + + diff --git a/frontend/src/components/admin/risk/RiskTrendChart.vue b/frontend/src/components/admin/risk/RiskTrendChart.vue new file mode 100644 index 00000000..30f5ef66 --- /dev/null +++ b/frontend/src/components/admin/risk/RiskTrendChart.vue @@ -0,0 +1,62 @@ + + + diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index 2e5babeb..8cd2d6e5 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -482,6 +482,22 @@ const ChevronDoubleRightIcon = { ) } + +const ShieldExclamationIcon = { + render: () => + h( + 'svg', + { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, + [ + h('path', { + 'stroke-linecap': 'round', + 'stroke-linejoin': 'round', + d: 'M12 9v3.75m0-10.036A11.959 11.959 0 013.598 6 11.99 11.99 0 003 9.75c0 5.592 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.31-.21-2.57-.598-3.75h-.152c-3.196 0-6.1-1.249-8.25-3.286zm0 13.036h.008v.008H12v-.008z' + }) + ] + ) +} + // User navigation items (for regular users) const userNavItems = computed((): NavItem[] => { const items: NavItem[] = [ @@ -574,7 +590,8 @@ const adminNavItems = computed((): NavItem[] => { { path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon }, { path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true }, { path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true }, - { path: '/admin/usage', label: t('nav.usage'), icon: ChartIcon } + { path: '/admin/usage', label: t('nav.usage'), icon: ChartIcon }, + { path: '/admin/risk', label: 'Risk Control', icon: ShieldExclamationIcon } ] // 简单模式下,在系统设置前插入 API密钥 diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 0ffef1a3..99c68ef5 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -375,6 +375,17 @@ const routes: RouteRecordRaw[] = [ } }, + { + path: '/admin/risk', + name: 'AdminRiskControl', + component: () => import('@/views/admin/RiskControlView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Risk Control' + } + }, + // ==================== 404 Not Found ==================== { path: '/:pathMatch(.*)*', diff --git a/frontend/src/stores/risk.ts b/frontend/src/stores/risk.ts new file mode 100644 index 00000000..1dadf038 --- /dev/null +++ b/frontend/src/stores/risk.ts @@ -0,0 +1,117 @@ +import { defineStore } from 'pinia' +import { ref } from 'vue' +import { + getRiskSummary, + listRiskAccounts, + getRiskAccountDetail, + getRiskSettings, + updateRiskSettings, + overrideRiskLevel +} from '@/api/admin/risk' +import type { + RiskSummary, + RiskAccountList, + RiskAccountDetail, + RiskSettings, + RiskAccountFilter +} from '@/api/admin/risk' + +export const useRiskStore = defineStore('risk', () => { + const summary = ref(null) + const accounts = ref(null) + const accountDetail = ref(null) + const settings = ref(null) + const loading = ref(false) + const error = ref(null) + + async function fetchSummary() { + loading.value = true + error.value = null + try { + summary.value = await getRiskSummary() + } catch (e: any) { + error.value = e?.message ?? 'Failed to load summary' + } finally { + loading.value = false + } + } + + async function fetchAccounts(filter: RiskAccountFilter = {}) { + loading.value = true + error.value = null + try { + accounts.value = await listRiskAccounts(filter) + } catch (e: any) { + error.value = e?.message ?? 'Failed to load accounts' + } finally { + loading.value = false + } + } + + async function fetchAccountDetail(id: number) { + loading.value = true + error.value = null + try { + accountDetail.value = await getRiskAccountDetail(id) + } catch (e: any) { + error.value = e?.message ?? 'Failed to load account detail' + } finally { + loading.value = false + } + } + + async function fetchSettings() { + loading.value = true + error.value = null + try { + settings.value = await getRiskSettings() + } catch (e: any) { + error.value = e?.message ?? 'Failed to load settings' + } finally { + loading.value = false + } + } + + async function saveSettings(updated: RiskSettings) { + loading.value = true + error.value = null + try { + settings.value = await updateRiskSettings(updated) + return true + } catch (e: any) { + error.value = e?.message ?? 'Failed to save settings' + return false + } finally { + loading.value = false + } + } + + async function overrideAccount(id: number, level: string, reason: string) { + loading.value = true + error.value = null + try { + await overrideRiskLevel(id, level, reason) + return true + } catch (e: any) { + error.value = e?.message ?? 'Failed to override risk level' + return false + } finally { + loading.value = false + } + } + + return { + summary, + accounts, + accountDetail, + settings, + loading, + error, + fetchSummary, + fetchAccounts, + fetchAccountDetail, + fetchSettings, + saveSettings, + overrideAccount + } +}) diff --git a/frontend/src/views/admin/RiskControlView.vue b/frontend/src/views/admin/RiskControlView.vue new file mode 100644 index 00000000..5fb6c0b2 --- /dev/null +++ b/frontend/src/views/admin/RiskControlView.vue @@ -0,0 +1,198 @@ + + +