sub2api/backend/internal/service/risk_service.go
win f25dd04e0b
Some checks failed
CI / test (push) Failing after 1m31s
CI / golangci-lint (push) Failing after 3s
Security Scan / backend-security (push) Failing after 3s
Security Scan / frontend-security (push) Failing after 2s
feat(risk): 风控数据管道与风控中心
- DB Migration 081: 新增 account_behavior_hourly / account_risk_scores 表
- 行为采集:Gateway/OpenAI Gateway RecordUsage 注入 fire-and-forget CollectBehaviorAsync
- SQL 打分引擎:CTE 加权特征向量 → risk_score [0-1],UPSERT 保留 idle_override
- RiskSettings:Redis 缓存 → DB fallback → 默认值(observe 模式)
- REST API:/admin/risk/summary|accounts|accounts/:id|settings
- 前端:Pinia store + RiskControlView + 6 子组件(donut/radar/line 纯 SVG 图表)
- 侧边栏新增 Risk Control 入口(ShieldExclamationIcon)
- 反风控优化:移除 Antigravity 后台定时刷新,改为按需刷新避免 idle 封号
2026-03-28 03:07:17 +08:00

248 lines
7.0 KiB
Go

package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/redis/go-redis/v9"
)
// riskSettingsCacheTTL is how long a resolved copy of the risk settings is kept
// in Redis before loadRiskSettings goes back to the setting repository.
const riskSettingsCacheTTL time.Duration = 5 * time.Minute

// Validation errors produced by RiskService.
var (
	// ErrRiskOverrideReasonRequired is returned by OverrideRiskLevel when the
	// reason is empty after trimming whitespace.
	ErrRiskOverrideReasonRequired = infraerrors.BadRequest("RISK_OVERRIDE_REASON_REQUIRED", "override reason is required")
	// ErrRiskSettingsInvalid is the base error (wrapped with a specific cause)
	// returned when risk settings fail validation.
	ErrRiskSettingsInvalid = infraerrors.BadRequest("RISK_SETTINGS_INVALID", "risk settings are invalid")
)

// RiskService collects per-account behavior data, exposes risk scores and
// summaries, and manages the risk-control settings.
type RiskService struct {
	repo        RiskRepository    // behavior aggregates and risk scores
	settingRepo SettingRepository // durable storage for the risk settings JSON
	redis       *redis.Client     // optional settings cache; nil disables caching
}

// NewRiskService wires a RiskService from its dependencies. redisClient may be
// nil, in which case settings are always read from settingRepo.
func NewRiskService(repo RiskRepository, settingRepo SettingRepository, redisClient *redis.Client) *RiskService {
	svc := new(RiskService)
	svc.repo = repo
	svc.settingRepo = settingRepo
	svc.redis = redisClient
	return svc
}
// CollectBehaviorAsync records one usage event into the hourly behavior
// aggregate of usageLog's account without blocking the caller, then (unless the
// configured phase is RiskPhaseOff) calls GetOrCreateRiskScore for that account.
//
// Only OAuth accounts are tracked; nil receivers/arguments are ignored. The work
// runs in a detached goroutine on a fresh background context with a 3s timeout.
// The ctx argument is not used, so cancelling the request does not abort the
// collection. Failures are logged and otherwise ignored (best effort).
//
// Every field needed from usageLog is copied before the goroutine starts, so the
// caller is free to reuse or mutate usageLog once this method returns.
func (s *RiskService) CollectBehaviorAsync(ctx context.Context, account *Account, usageLog *UsageLog) {
	if s == nil || s.repo == nil || account == nil || usageLog == nil {
		return
	}
	if !account.IsOAuth() {
		return
	}

	// Snapshot on the caller's goroutine: reading usageLog from the background
	// goroutine would race with any later reuse of the log by the caller.
	accountID := usageLog.AccountID
	createdAt := usageLog.CreatedAt
	if createdAt.IsZero() {
		createdAt = time.Now()
	}
	var durationMs *int
	if usageLog.DurationMs != nil {
		d := *usageLog.DurationMs
		durationMs = &d // private copy; do not share the caller's pointer
	}
	delta := RiskBehaviorHourDelta{
		APICallCount:      1,
		StreamCount:       riskBoolToInt64(usageLog.Stream),
		TotalInputTokens:  int64(usageLog.InputTokens),
		TotalOutputTokens: int64(usageLog.OutputTokens),
		TotalDurationMs:   riskIntPtrToInt64(durationMs),
		P50DurationMs:     durationMs,
	}

	go func() {
		// Nobody waits on this goroutine; an unrecovered panic here would take
		// down the whole process instead of just dropping one data point.
		defer func() {
			if r := recover(); r != nil {
				slog.Error("risk behavior collection panicked", "account_id", accountID, "panic", r)
			}
		}()

		bg, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()

		if err := s.repo.UpsertBehaviorHour(bg, accountID, createdAt, delta); err != nil {
			slog.Warn("risk behavior upsert failed", "account_id", accountID, "error", err)
			return
		}

		settings, err := loadRiskSettings(bg, s.settingRepo, s.redis)
		if err != nil {
			settings = DefaultRiskSettings()
		}
		// Behavior is still recorded above when scoring is switched off.
		if settings.Phase == RiskPhaseOff {
			return
		}
		if _, err := s.repo.GetOrCreateRiskScore(bg, accountID); err != nil {
			slog.Warn("risk score refresh failed", "account_id", accountID, "error", err)
		}
	}()
}
// GetSummary returns the aggregate risk overview produced by the repository.
// It fails when the service or its repository has not been wired.
func (s *RiskService) GetSummary(ctx context.Context) (*RiskSummary, error) {
	ready := s != nil && s.repo != nil
	if !ready {
		return nil, fmt.Errorf("risk service not initialized")
	}
	return s.repo.GetRiskSummary(ctx)
}
// ListAccounts returns one page of risk-scored accounts matching filter.
//
// The filter is sanitized before querying:
//   - Page defaults to 1 when not positive.
//   - PageSize defaults to 20 when not positive and is capped at 200.
//   - Level is trimmed and upper-cased; Platform is trimmed.
func (s *RiskService) ListAccounts(ctx context.Context, filter RiskAccountFilter) (*RiskAccountList, error) {
	if s == nil || s.repo == nil {
		return nil, fmt.Errorf("risk service not initialized")
	}

	if filter.Page < 1 {
		filter.Page = 1
	}
	switch {
	case filter.PageSize <= 0:
		filter.PageSize = 20
	case filter.PageSize > 200:
		filter.PageSize = 200
	}

	trimmedLevel := strings.TrimSpace(filter.Level)
	filter.Level = strings.ToUpper(trimmedLevel)
	filter.Platform = strings.TrimSpace(filter.Platform)

	return s.repo.ListRiskAccounts(ctx, filter)
}
// GetAccountDetail returns the risk details of a single account.
// Non-positive IDs are reported as ErrRiskAccountNotFound without touching
// the repository.
func (s *RiskService) GetAccountDetail(ctx context.Context, accountID int64) (*RiskAccountDetail, error) {
	switch {
	case s == nil || s.repo == nil:
		return nil, fmt.Errorf("risk service not initialized")
	case accountID <= 0:
		return nil, ErrRiskAccountNotFound
	}
	return s.repo.GetRiskAccountDetail(ctx, accountID)
}
// OverrideRiskLevel manually pins an account's risk level.
//
// The level is trimmed and upper-cased and must be RiskLevelLow,
// RiskLevelMedium or RiskLevelHigh. The reason is trimmed and must be
// non-empty. Checks run in this order:
//  1. service wiring;
//  2. account ID (ErrRiskAccountNotFound);
//  3. reason (ErrRiskOverrideReasonRequired);
//  4. level (ErrRiskLevelInvalid).
func (s *RiskService) OverrideRiskLevel(ctx context.Context, accountID int64, level, reason string) error {
	if s == nil || s.repo == nil {
		return fmt.Errorf("risk service not initialized")
	}
	if accountID <= 0 {
		return ErrRiskAccountNotFound
	}

	cleanLevel := strings.ToUpper(strings.TrimSpace(level))
	cleanReason := strings.TrimSpace(reason)

	if cleanReason == "" {
		return ErrRiskOverrideReasonRequired
	}
	if cleanLevel != RiskLevelLow && cleanLevel != RiskLevelMedium && cleanLevel != RiskLevelHigh {
		return ErrRiskLevelInvalid
	}
	return s.repo.OverrideRiskLevel(ctx, accountID, cleanLevel, cleanReason)
}
// GetSettings returns the effective risk settings. An unwired service (nil
// receiver or nil setting repository) yields DefaultRiskSettings().
func (s *RiskService) GetSettings(ctx context.Context) (*RiskSettings, error) {
	if s != nil && s.settingRepo != nil {
		return loadRiskSettings(ctx, s.settingRepo, s.redis)
	}
	return DefaultRiskSettings(), nil
}
// UpdateSettings validates and normalizes settings, persists them as JSON under
// SettingKeyRiskSettings, and invalidates the Redis cache entry so readers pick
// up the new value. It returns the normalized settings that were stored.
//
// Validation failures are returned from normalizeRiskSettings unchanged. Cache
// invalidation is best effort: a failure is logged, not returned, because the
// durable write has already succeeded.
func (s *RiskService) UpdateSettings(ctx context.Context, settings *RiskSettings) (*RiskSettings, error) {
	if s == nil || s.settingRepo == nil {
		return nil, fmt.Errorf("risk service not initialized")
	}
	normalized, err := normalizeRiskSettings(settings)
	if err != nil {
		return nil, err
	}
	data, err := json.Marshal(normalized)
	if err != nil {
		return nil, err
	}
	if err := s.settingRepo.Set(ctx, SettingKeyRiskSettings, string(data)); err != nil {
		return nil, err
	}
	if s.redis != nil {
		// The new value is already persisted, so invalidate even if the caller's
		// ctx has been cancelled in the meantime; otherwise stale settings could
		// be served until the cache entry expires.
		delCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 2*time.Second)
		defer cancel()
		if err := s.redis.Del(delCtx, riskSettingsCacheKey).Err(); err != nil {
			slog.Warn("risk settings cache invalidation failed; stale settings may be served until expiry",
				"error", err, "ttl", riskSettingsCacheTTL)
		}
	}
	return normalized, nil
}
// loadRiskSettings resolves the effective risk settings, in order of preference:
//  1. a cached copy in Redis that decodes and validates (when redisClient != nil);
//  2. the JSON stored under SettingKeyRiskSettings in settingRepo;
//  3. DefaultRiskSettings().
//
// The resolved value is always normalized and, when Redis is available, written
// back to the cache. If the repository read itself returned an error, the
// fallback defaults are cached with a short TTL only, so a transient database
// failure cannot hide the real settings for the full riskSettingsCacheTTL.
// (Whether the repository reports a missing key as an error is not visible
// here — TODO confirm.) The returned error is currently always nil.
func loadRiskSettings(ctx context.Context, settingRepo SettingRepository, redisClient *redis.Client) (*RiskSettings, error) {
	// degradedCacheTTL is used when the repository read failed.
	const degradedCacheTTL = 30 * time.Second

	if ctx == nil {
		ctx = context.Background()
	}

	if redisClient != nil {
		if raw, err := redisClient.Get(ctx, riskSettingsCacheKey).Result(); err == nil && strings.TrimSpace(raw) != "" {
			cached := DefaultRiskSettings()
			if err := json.Unmarshal([]byte(raw), cached); err == nil {
				if normalized, err := normalizeRiskSettings(cached); err == nil {
					return normalized, nil
				}
			}
			// Undecodable or invalid cache entries fall through and get overwritten.
		}
	}

	cacheTTL := riskSettingsCacheTTL
	settings := DefaultRiskSettings()
	if settingRepo != nil {
		raw, err := settingRepo.GetValue(ctx, SettingKeyRiskSettings)
		switch {
		case err != nil:
			// Serve defaults, but don't pin them in the cache for long.
			cacheTTL = degradedCacheTTL
		case strings.TrimSpace(raw) != "":
			if unmarshalErr := json.Unmarshal([]byte(raw), settings); unmarshalErr != nil {
				slog.Warn("risk settings json invalid; using defaults", "error", unmarshalErr)
				settings = DefaultRiskSettings()
			}
		}
	}

	normalized, err := normalizeRiskSettings(settings)
	if err != nil {
		slog.Warn("stored risk settings failed validation; using defaults", "error", err)
		normalized = DefaultRiskSettings()
	}

	if redisClient != nil {
		if data, marshalErr := json.Marshal(normalized); marshalErr == nil {
			// Best effort: a failed cache write only costs another repository read.
			_ = redisClient.Set(ctx, riskSettingsCacheKey, string(data), cacheTTL).Err()
		}
	}
	return normalized, nil
}
// normalizeRiskSettings validates settings and returns a new, normalized value;
// the input is never mutated. Only MediumThreshold, HighThreshold and Phase are
// carried over.
//
// Defaults and normalization:
//   - A nil input, or one with both thresholds zero and a blank phase, yields
//     DefaultRiskSettings().
//   - The phase is trimmed and lower-cased; a blank phase becomes RiskPhaseObserve.
//
// Validation (failures return ErrRiskSettingsInvalid wrapped with a cause):
//   - Both thresholds must lie in [0, 1]. NaN is rejected.
//   - MediumThreshold must be strictly below HighThreshold.
//   - Phase must be RiskPhaseOff, RiskPhaseObserve or RiskPhaseEnforce.
func normalizeRiskSettings(settings *RiskSettings) (*RiskSettings, error) {
	if settings == nil {
		return DefaultRiskSettings(), nil
	}
	out := &RiskSettings{
		MediumThreshold: settings.MediumThreshold,
		HighThreshold:   settings.HighThreshold,
		Phase:           strings.ToLower(strings.TrimSpace(settings.Phase)),
	}
	if out.MediumThreshold == 0 && out.HighThreshold == 0 && out.Phase == "" {
		return DefaultRiskSettings(), nil
	}
	if out.Phase == "" {
		out.Phase = RiskPhaseObserve
	}

	// Range checks are phrased positively: every comparison involving NaN is
	// false, so "x < 0 || x > 1" would let NaN through, while this rejects it.
	mediumInRange := out.MediumThreshold >= 0 && out.MediumThreshold <= 1
	if !mediumInRange {
		return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("medium_threshold must be between 0 and 1"))
	}
	highInRange := out.HighThreshold >= 0 && out.HighThreshold <= 1
	if !highInRange {
		return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("high_threshold must be between 0 and 1"))
	}
	if out.MediumThreshold >= out.HighThreshold {
		return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("medium_threshold must be less than high_threshold"))
	}

	switch out.Phase {
	case RiskPhaseOff, RiskPhaseObserve, RiskPhaseEnforce:
	default:
		return nil, ErrRiskSettingsInvalid.WithCause(fmt.Errorf("phase must be one of: %s, %s, %s", RiskPhaseOff, RiskPhaseObserve, RiskPhaseEnforce))
	}
	return out, nil
}
// riskBoolToInt64 maps true to 1 and false to 0, for use as a counter increment.
func riskBoolToInt64(v bool) int64 {
	var n int64
	if v {
		n = 1
	}
	return n
}
// riskIntPtrToInt64 dereferences v as an int64, treating a nil pointer as 0.
func riskIntPtrToInt64(v *int) int64 {
	if v != nil {
		return int64(*v)
	}
	return 0
}