Compare commits
26 Commits
b5642bd068
...
1dfd974432
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dfd974432 | ||
|
|
cc396f59cf | ||
|
|
aa8b9cc508 | ||
|
|
6a2cf09ee0 | ||
|
|
c6fd88116b | ||
|
|
8f0dbdeaba | ||
|
|
007c09b84e | ||
|
|
73f3c068ef | ||
|
|
9a92fa4a60 | ||
|
|
576af710be | ||
|
|
f2c2abe628 | ||
|
|
ff5b467fbe | ||
|
|
8c10941142 | ||
|
|
f5764d8dc6 | ||
|
|
81ca4f12dd | ||
|
|
941c469ab9 | ||
|
|
8fcd819e6f | ||
|
|
9abdaed20c | ||
|
|
eb94342f78 | ||
|
|
d563eb2336 | ||
|
|
3ee6f085db | ||
|
|
7cca69a136 | ||
|
|
093a5a260e | ||
|
|
2c072c0ed6 | ||
|
|
1f39bf8a78 | ||
|
|
c7f4a649df |
@ -49,9 +49,13 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="150"></a></td>
|
||||
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=sub2api"><img src="assets/partners/logos/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode offers a special discount to users of this software: register using <a href="https://www.packyapi.com/register?aff=sub2api">this link</a> and enter the "sub2api" promo code on your first recharge to get 10% off.</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Ecosystem
|
||||
|
||||
@ -48,9 +48,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="150"></a></td>
|
||||
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=sub2api"><img src="assets/partners/logos/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>感谢 PackyCode 赞助了本项目!PackyCode 是一家稳定、高效的API中转服务商,提供 Claude Code、Codex、Gemini 等多种中转服务。PackyCode 为本软件的用户提供了特别优惠,使用<a href="https://www.packyapi.com/register?aff=sub2api">此链接</a>注册并在充值时填写"sub2api"优惠码,首次充值可以享受9折优惠!</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 生态项目
|
||||
|
||||
@ -49,9 +49,13 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="120"></a></td>
|
||||
<td width="180" align="center" valign="middle"><a href="https://shop.pincc.ai/"><img src="assets/partners/logos/pincc-logo.png" alt="pincc" width="150"></a></td>
|
||||
<td valign="middle"><b><a href="https://shop.pincc.ai/">PinCC</a></b> は Sub2API 上に構築された公式リレーサービスで、Claude Code、Codex、Gemini などの人気モデルへの安定したアクセスを提供します。デプロイやメンテナンスは不要で、すぐにご利用いただけます。</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=sub2api"><img src="assets/partners/logos/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>PackyCode のご支援に感謝します!PackyCode は Claude Code、Codex、Gemini などのリレーサービスを提供する信頼性の高い API 中継プラットフォームです。本ソフト利用者向けに特別割引があります:<a href="https://www.packyapi.com/register?aff=sub2api">このリンク</a>で登録し、チャージ時に「sub2api」クーポンを入力すると 10% オフになります。</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## エコシステム
|
||||
|
||||
BIN
assets/partners/logos/packycode.png
Normal file
BIN
assets/partners/logos/packycode.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.1 KiB |
@ -1 +1 @@
|
||||
0.1.105
|
||||
0.1.106
|
||||
|
||||
@ -137,7 +137,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
|
||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
|
||||
@ -541,6 +541,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
@ -606,7 +607,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
apiKey.GroupID,
|
||||
"", // no previous_response_id
|
||||
sessionHash,
|
||||
reqModel,
|
||||
routingModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
@ -621,7 +622,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
if defaultModel != "" && defaultModel != routingModel {
|
||||
reqLog.Info("openai_messages.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
|
||||
@ -24,20 +24,18 @@ const (
|
||||
RedirectURI = "https://platform.claude.com/oauth/code/callback"
|
||||
|
||||
// Scopes - Browser URL (includes org:create_api_key for user authorization)
|
||||
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||
ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"
|
||||
// Scopes - Internal API call (org:create_api_key not supported in API)
|
||||
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
|
||||
ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload"
|
||||
// Scopes - Setup token (inference only)
|
||||
ScopeInference = "user:inference"
|
||||
|
||||
// Code Verifier character set (RFC 7636 compliant)
|
||||
codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
||||
|
||||
// Session TTL
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state
|
||||
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
@ -147,30 +145,14 @@ func GenerateSessionID() (string, error) {
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeVerifier generates a PKCE code verifier using character set method
|
||||
// GenerateCodeVerifier generates a PKCE code verifier (RFC 7636).
|
||||
// Uses 32 random bytes → base64url-no-pad, producing a 43-char verifier.
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
const targetLen = 32
|
||||
charsetLen := len(codeVerifierCharset)
|
||||
limit := 256 - (256 % charsetLen)
|
||||
|
||||
result := make([]byte, 0, targetLen)
|
||||
randBuf := make([]byte, targetLen*2)
|
||||
|
||||
for len(result) < targetLen {
|
||||
if _, err := rand.Read(randBuf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, b := range randBuf {
|
||||
if int(b) < limit {
|
||||
result = append(result, codeVerifierCharset[int(b)%charsetLen])
|
||||
if len(result) >= targetLen {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64URLEncode(result), nil
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
|
||||
|
||||
@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@ -257,9 +258,12 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 存在唯一键约束 生成tombstone key 用来释放原key,长度远小于 128,满足 schema 限制
|
||||
tombstoneKey := fmt.Sprintf("__deleted__%d__%d", id, time.Now().UnixNano())
|
||||
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetKey(tombstoneKey).
|
||||
SetDeletedAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
|
||||
@ -151,6 +151,31 @@ func (s *APIKeyRepoSuite) TestDelete() {
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestCreate_AfterSoftDelete_AllowsSameKey() {
|
||||
user := s.mustCreateUser("recreate-after-soft-delete@test.com")
|
||||
const reusedKey = "sk-reuse-after-soft-delete"
|
||||
|
||||
first := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: reusedKey,
|
||||
Name: "First Key",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, first), "create first key")
|
||||
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, first.ID), "soft delete first key")
|
||||
|
||||
second := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: reusedKey,
|
||||
Name: "Second Key",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, second), "create second key with same key")
|
||||
s.Require().NotZero(second.ID)
|
||||
s.Require().NotEqual(first.ID, second.ID, "recreated key should be a new row")
|
||||
}
|
||||
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestListByUserID() {
|
||||
|
||||
55
backend/internal/repository/internal500_counter_cache.go
Normal file
55
backend/internal/repository/internal500_counter_cache.go
Normal file
@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
	// internal500CounterPrefix is prepended to the account ID to build the
	// Redis key of that account's counter.
	internal500CounterPrefix = "internal500_count:account:"
	// internal500CounterTTLSeconds is the fallback expiry handed to the
	// increment script: 24 hours, expressed in seconds.
	internal500CounterTTLSeconds = 24 * 60 * 60
)
|
||||
|
||||
// internal500CounterIncrScript is a Lua script that increments the counter and returns its current value in one atomic step.
|
||||
// When the increment creates the key (the count comes back as 1), the script also sets the expiry passed in ARGV[1].
|
||||
var internal500CounterIncrScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
|
||||
local count = redis.call('INCR', key)
|
||||
if count == 1 then
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
|
||||
return count
|
||||
`)
|
||||
|
||||
type internal500CounterCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewInternal500CounterCache 创建 INTERNAL 500 连续失败计数器缓存实例
|
||||
func NewInternal500CounterCache(rdb *redis.Client) service.Internal500CounterCache {
|
||||
return &internal500CounterCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// IncrementInternal500Count 原子递增计数并返回当前值
|
||||
func (c *internal500CounterCache) IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error) {
|
||||
key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID)
|
||||
|
||||
result, err := internal500CounterIncrScript.Run(ctx, c.rdb, []string{key}, internal500CounterTTLSeconds).Int64()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("increment internal500 count: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ResetInternal500Count 清零计数器(成功响应时调用)
|
||||
func (c *internal500CounterCache) ResetInternal500Count(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@ -81,6 +81,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyCache,
|
||||
NewTempUnschedCache,
|
||||
NewTimeoutCounterCache,
|
||||
NewInternal500CounterCache,
|
||||
ProvideConcurrencyCache,
|
||||
ProvideSessionLimitCache,
|
||||
NewRPMCache,
|
||||
|
||||
@ -614,6 +614,7 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
|
||||
urlFallbackLoop:
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
usedBaseURL = baseURL
|
||||
allAttemptsInternal500 := true // 追踪本轮所有 attempt 是否全部命中 INTERNAL 500
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
@ -766,10 +767,19 @@ urlFallbackLoop:
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix)
|
||||
return nil, p.ctx.Err()
|
||||
}
|
||||
// 追踪 INTERNAL 500:非匹配的 attempt 清除标记
|
||||
if !isAntigravityInternalServerError(resp.StatusCode, respBody) {
|
||||
allAttemptsInternal500 = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// INTERNAL 500 渐进惩罚:3 次重试全部命中特定 500 时递增计数器并惩罚
|
||||
if allAttemptsInternal500 && isAntigravityInternalServerError(resp.StatusCode, respBody) {
|
||||
s.handleInternal500RetryExhausted(p.ctx, p.prefix, p.account)
|
||||
}
|
||||
|
||||
// 其他 4xx 错误或重试用尽,直接返回
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
@ -788,6 +798,11 @@ urlFallbackLoop:
|
||||
antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL)
|
||||
}
|
||||
|
||||
// 成功响应时清零 INTERNAL 500 连续失败计数器(覆盖所有成功路径,含 smart retry)
|
||||
if resp != nil && resp.StatusCode < 400 {
|
||||
s.resetInternal500Counter(p.ctx, p.prefix, p.account.ID)
|
||||
}
|
||||
|
||||
return &antigravityRetryLoopResult{resp: resp}, nil
|
||||
}
|
||||
|
||||
@ -862,6 +877,7 @@ type AntigravityGatewayService struct {
|
||||
settingService *SettingService
|
||||
cache GatewayCache // 用于模型级限流时清除粘性会话绑定
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器
|
||||
}
|
||||
|
||||
func NewAntigravityGatewayService(
|
||||
@ -872,6 +888,7 @@ func NewAntigravityGatewayService(
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
settingService *SettingService,
|
||||
internal500Cache Internal500CounterCache,
|
||||
) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@ -881,6 +898,7 @@ func NewAntigravityGatewayService(
|
||||
settingService: settingService,
|
||||
cache: cache,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
internal500Cache: internal500Cache,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
97
backend/internal/service/antigravity_internal500_penalty.go
Normal file
97
backend/internal/service/antigravity_internal500_penalty.go
Normal file
@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Progressive INTERNAL 500 penalty: how long an account is sidelined after
// consecutive rounds in which every attempt returned the specific 500 error.
const (
	// Round 1: temporarily unschedulable for 30 minutes.
	internal500PenaltyTier1Duration = time.Minute * 30
	// Round 2: temporarily unschedulable for 2 hours.
	internal500PenaltyTier2Duration = time.Hour * 2
	// Round 3 and later: the account is put into the error state.
	internal500PenaltyTier3Threshold = 3
)
|
||||
|
||||
// isAntigravityInternalServerError 检测特定的 INTERNAL 500 错误
|
||||
// 必须同时匹配 error.code==500, error.message=="Internal error encountered.", error.status=="INTERNAL"
|
||||
func isAntigravityInternalServerError(statusCode int, body []byte) bool {
|
||||
if statusCode != http.StatusInternalServerError {
|
||||
return false
|
||||
}
|
||||
return gjson.GetBytes(body, "error.code").Int() == 500 &&
|
||||
gjson.GetBytes(body, "error.message").String() == "Internal error encountered." &&
|
||||
gjson.GetBytes(body, "error.status").String() == "INTERNAL"
|
||||
}
|
||||
|
||||
// applyInternal500Penalty 根据连续 INTERNAL 500 轮次数应用渐进惩罚
|
||||
// count=1: temp_unschedulable 10 分钟
|
||||
// count=2: temp_unschedulable 10 小时
|
||||
// count>=3: SetError 永久禁用
|
||||
func (s *AntigravityGatewayService) applyInternal500Penalty(
|
||||
ctx context.Context, prefix string, account *Account, count int64,
|
||||
) {
|
||||
switch {
|
||||
case count >= int64(internal500PenaltyTier3Threshold):
|
||||
reason := fmt.Sprintf("INTERNAL 500 consecutive failures: %d rounds", count)
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, reason); err != nil {
|
||||
slog.Error("internal500_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("internal500_account_disabled",
|
||||
"account_id", account.ID, "account_name", account.Name, "consecutive_count", count)
|
||||
case count == 2:
|
||||
until := time.Now().Add(internal500PenaltyTier2Duration)
|
||||
reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier2Duration)
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("internal500_temp_unschedulable",
|
||||
"account_id", account.ID, "account_name", account.Name,
|
||||
"duration", internal500PenaltyTier2Duration, "consecutive_count", count)
|
||||
case count == 1:
|
||||
until := time.Now().Add(internal500PenaltyTier1Duration)
|
||||
reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier1Duration)
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("internal500_temp_unschedulable",
|
||||
"account_id", account.ID, "account_name", account.Name,
|
||||
"duration", internal500PenaltyTier1Duration, "consecutive_count", count)
|
||||
}
|
||||
}
|
||||
|
||||
// handleInternal500RetryExhausted 处理 INTERNAL 500 重试耗尽:递增计数器并应用惩罚
|
||||
func (s *AntigravityGatewayService) handleInternal500RetryExhausted(
|
||||
ctx context.Context, prefix string, account *Account,
|
||||
) {
|
||||
if s.internal500Cache == nil {
|
||||
return
|
||||
}
|
||||
count, err := s.internal500Cache.IncrementInternal500Count(ctx, account.ID)
|
||||
if err != nil {
|
||||
slog.Error("internal500_counter_increment_failed",
|
||||
"prefix", prefix, "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
s.applyInternal500Penalty(ctx, prefix, account, count)
|
||||
}
|
||||
|
||||
// resetInternal500Counter 成功响应时清零 INTERNAL 500 计数器
|
||||
func (s *AntigravityGatewayService) resetInternal500Counter(
|
||||
ctx context.Context, prefix string, accountID int64,
|
||||
) {
|
||||
if s.internal500Cache == nil {
|
||||
return
|
||||
}
|
||||
if err := s.internal500Cache.ResetInternal500Count(ctx, accountID); err != nil {
|
||||
slog.Error("internal500_counter_reset_failed",
|
||||
"prefix", prefix, "account_id", accountID, "error", err)
|
||||
}
|
||||
}
|
||||
321
backend/internal/service/antigravity_internal500_penalty_test.go
Normal file
321
backend/internal/service/antigravity_internal500_penalty_test.go
Normal file
@ -0,0 +1,321 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- mock: Internal500CounterCache ---

// mockInternal500Cache is a hand-written test double: it returns the canned
// values below and records the account ID of every call it receives.
type mockInternal500Cache struct {
	incrementCount int64 // value returned by IncrementInternal500Count
	incrementErr   error // error returned by IncrementInternal500Count
	resetErr       error // error returned by ResetInternal500Count

	incrementCalls []int64 // account IDs seen by IncrementInternal500Count, in call order
	resetCalls     []int64 // account IDs seen by ResetInternal500Count, in call order
}

// IncrementInternal500Count records accountID and returns the canned count and error.
func (m *mockInternal500Cache) IncrementInternal500Count(_ context.Context, accountID int64) (int64, error) {
	m.incrementCalls = append(m.incrementCalls, accountID)
	count, err := m.incrementCount, m.incrementErr
	return count, err
}

// ResetInternal500Count records accountID and returns the canned error.
func (m *mockInternal500Cache) ResetInternal500Count(_ context.Context, accountID int64) error {
	m.resetCalls = append(m.resetCalls, accountID)
	err := m.resetErr
	return err
}
|
||||
|
||||
// --- mock: 专用于 internal500 惩罚测试的 AccountRepository ---
|
||||
|
||||
type internal500AccountRepoStub struct {
|
||||
AccountRepository // 嵌入接口,未实现的方法会 panic(不应被调用)
|
||||
|
||||
tempUnschedCalls []tempUnschedCall
|
||||
setErrorCalls []setErrorCall
|
||||
}
|
||||
|
||||
type tempUnschedCall struct {
|
||||
accountID int64
|
||||
until time.Time
|
||||
reason string
|
||||
}
|
||||
|
||||
type setErrorCall struct {
|
||||
accountID int64
|
||||
reason string
|
||||
}
|
||||
|
||||
func (r *internal500AccountRepoStub) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
|
||||
r.tempUnschedCalls = append(r.tempUnschedCalls, tempUnschedCall{accountID: id, until: until, reason: reason})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *internal500AccountRepoStub) SetError(_ context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls = append(r.setErrorCalls, setErrorCall{accountID: id, reason: errorMsg})
|
||||
return nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestIsAntigravityInternalServerError
|
||||
// =============================================================================
|
||||
|
||||
func TestIsAntigravityInternalServerError(t *testing.T) {
|
||||
t.Run("匹配完整的 INTERNAL 500 body", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||||
require.True(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("statusCode 不是 500", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(429, body))
|
||||
require.False(t, isAntigravityInternalServerError(503, body))
|
||||
require.False(t, isAntigravityInternalServerError(200, body))
|
||||
})
|
||||
|
||||
t.Run("body 中 message 不匹配", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Some other error","status":"INTERNAL"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("body 中 status 不匹配", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"UNAVAILABLE"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("body 中 code 不匹配", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":503,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("空 body", func(t *testing.T) {
|
||||
require.False(t, isAntigravityInternalServerError(500, []byte{}))
|
||||
require.False(t, isAntigravityInternalServerError(500, nil))
|
||||
})
|
||||
|
||||
t.Run("其他 500 错误格式(纯文本)", func(t *testing.T) {
|
||||
body := []byte(`Internal Server Error`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("其他 500 错误格式(不同 JSON 结构)", func(t *testing.T) {
|
||||
body := []byte(`{"message":"Internal Server Error","statusCode":500}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestApplyInternal500Penalty
|
||||
// =============================================================================
|
||||
|
||||
func TestApplyInternal500Penalty(t *testing.T) {
|
||||
t.Run("count=1 → SetTempUnschedulable 10 分钟", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 1, Name: "acc-1"}
|
||||
|
||||
before := time.Now()
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 1)
|
||||
after := time.Now()
|
||||
|
||||
require.Len(t, repo.tempUnschedCalls, 1)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
|
||||
call := repo.tempUnschedCalls[0]
|
||||
require.Equal(t, int64(1), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500")
|
||||
// until 应在 [before+10m, after+10m] 范围内
|
||||
require.True(t, call.until.After(before.Add(internal500PenaltyTier1Duration).Add(-time.Second)))
|
||||
require.True(t, call.until.Before(after.Add(internal500PenaltyTier1Duration).Add(time.Second)))
|
||||
})
|
||||
|
||||
t.Run("count=2 → SetTempUnschedulable 10 小时", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 2, Name: "acc-2"}
|
||||
|
||||
before := time.Now()
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 2)
|
||||
after := time.Now()
|
||||
|
||||
require.Len(t, repo.tempUnschedCalls, 1)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
|
||||
call := repo.tempUnschedCalls[0]
|
||||
require.Equal(t, int64(2), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500")
|
||||
require.True(t, call.until.After(before.Add(internal500PenaltyTier2Duration).Add(-time.Second)))
|
||||
require.True(t, call.until.Before(after.Add(internal500PenaltyTier2Duration).Add(time.Second)))
|
||||
})
|
||||
|
||||
t.Run("count=3 → SetError 永久禁用", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 3, Name: "acc-3"}
|
||||
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 3)
|
||||
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Len(t, repo.setErrorCalls, 1)
|
||||
|
||||
call := repo.setErrorCalls[0]
|
||||
require.Equal(t, int64(3), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 3")
|
||||
})
|
||||
|
||||
t.Run("count=5 → SetError 永久禁用(>=3 都走永久禁用)", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 5, Name: "acc-5"}
|
||||
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 5)
|
||||
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Len(t, repo.setErrorCalls, 1)
|
||||
|
||||
call := repo.setErrorCalls[0]
|
||||
require.Equal(t, int64(5), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 5")
|
||||
})
|
||||
|
||||
t.Run("count=0 → 不调用任何方法", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 10, Name: "acc-10"}
|
||||
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 0)
|
||||
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestHandleInternal500RetryExhausted
|
||||
// =============================================================================
|
||||
|
||||
func TestHandleInternal500RetryExhausted(t *testing.T) {
|
||||
t.Run("internal500Cache 为 nil → 不 panic,不调用任何方法", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: nil,
|
||||
}
|
||||
account := &Account{ID: 1, Name: "acc-1"}
|
||||
|
||||
// 不应 panic
|
||||
require.NotPanics(t, func() {
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
})
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
|
||||
t.Run("IncrementInternal500Count 返回 error → 不调用惩罚方法", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
cache := &mockInternal500Cache{
|
||||
incrementErr: errors.New("redis connection error"),
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: cache,
|
||||
}
|
||||
account := &Account{ID: 2, Name: "acc-2"}
|
||||
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
|
||||
require.Len(t, cache.incrementCalls, 1)
|
||||
require.Equal(t, int64(2), cache.incrementCalls[0])
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
|
||||
t.Run("IncrementInternal500Count 返回 count=1 → 触发 tier1 惩罚", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
cache := &mockInternal500Cache{
|
||||
incrementCount: 1,
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: cache,
|
||||
}
|
||||
account := &Account{ID: 3, Name: "acc-3"}
|
||||
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
|
||||
require.Len(t, cache.incrementCalls, 1)
|
||||
require.Equal(t, int64(3), cache.incrementCalls[0])
|
||||
// tier1: SetTempUnschedulable
|
||||
require.Len(t, repo.tempUnschedCalls, 1)
|
||||
require.Equal(t, int64(3), repo.tempUnschedCalls[0].accountID)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
|
||||
t.Run("IncrementInternal500Count 返回 count=3 → 触发 tier3 永久禁用", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
cache := &mockInternal500Cache{
|
||||
incrementCount: 3,
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: cache,
|
||||
}
|
||||
account := &Account{ID: 4, Name: "acc-4"}
|
||||
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
|
||||
require.Len(t, cache.incrementCalls, 1)
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Len(t, repo.setErrorCalls, 1)
|
||||
require.Equal(t, int64(4), repo.setErrorCalls[0].accountID)
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestResetInternal500Counter
|
||||
// =============================================================================
|
||||
|
||||
func TestResetInternal500Counter(t *testing.T) {
|
||||
t.Run("internal500Cache 为 nil → 不 panic", func(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{
|
||||
internal500Cache: nil,
|
||||
}
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
svc.resetInternal500Counter(context.Background(), "[test]", 1)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ResetInternal500Count 返回 error → 不 panic(仅日志)", func(t *testing.T) {
|
||||
cache := &mockInternal500Cache{
|
||||
resetErr: errors.New("redis timeout"),
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
internal500Cache: cache,
|
||||
}
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
svc.resetInternal500Counter(context.Background(), "[test]", 42)
|
||||
})
|
||||
require.Len(t, cache.resetCalls, 1)
|
||||
require.Equal(t, int64(42), cache.resetCalls[0])
|
||||
})
|
||||
|
||||
t.Run("正常调用 → 调用 ResetInternal500Count", func(t *testing.T) {
|
||||
cache := &mockInternal500Cache{}
|
||||
svc := &AntigravityGatewayService{
|
||||
internal500Cache: cache,
|
||||
}
|
||||
|
||||
svc.resetInternal500Counter(context.Background(), "[test]", 99)
|
||||
|
||||
require.Len(t, cache.resetCalls, 1)
|
||||
require.Equal(t, int64(99), cache.resetCalls[0])
|
||||
})
|
||||
}
|
||||
11
backend/internal/service/internal500_counter.go
Normal file
11
backend/internal/service/internal500_counter.go
Normal file
@ -0,0 +1,11 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// Internal500CounterCache tracks, per Antigravity account, how many consecutive rounds have failed with INTERNAL 500.
|
||||
type Internal500CounterCache interface {
|
||||
// IncrementInternal500Count atomically increments the counter for accountID and returns its current value.
|
||||
IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error)
|
||||
// ResetInternal500Count resets the counter for accountID to zero (called on a successful response).
|
||||
ResetInternal500Count(ctx context.Context, accountID int64) error
|
||||
}
|
||||
103
backend/internal/service/openai_compat_model.go
Normal file
103
backend/internal/service/openai_compat_model.go
Normal file
@ -0,0 +1,103 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
)
|
||||
|
||||
func NormalizeOpenAICompatRequestedModel(model string) string {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalized, _, ok := splitOpenAICompatReasoningModel(trimmed)
|
||||
if !ok || normalized == "" {
|
||||
return trimmed
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func applyOpenAICompatModelNormalization(req *apicompat.AnthropicRequest) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
|
||||
originalModel := strings.TrimSpace(req.Model)
|
||||
if originalModel == "" {
|
||||
return
|
||||
}
|
||||
|
||||
normalizedModel, derivedEffort, hasReasoningSuffix := splitOpenAICompatReasoningModel(originalModel)
|
||||
if hasReasoningSuffix && normalizedModel != "" {
|
||||
req.Model = normalizedModel
|
||||
}
|
||||
|
||||
if req.OutputConfig != nil && strings.TrimSpace(req.OutputConfig.Effort) != "" {
|
||||
return
|
||||
}
|
||||
|
||||
claudeEffort := openAIReasoningEffortToClaudeOutputEffort(derivedEffort)
|
||||
if claudeEffort == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if req.OutputConfig == nil {
|
||||
req.OutputConfig = &apicompat.AnthropicOutputConfig{}
|
||||
}
|
||||
req.OutputConfig.Effort = claudeEffort
|
||||
}
|
||||
|
||||
func splitOpenAICompatReasoningModel(model string) (normalizedModel string, reasoningEffort string, ok bool) {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
if trimmed == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
modelID := trimmed
|
||||
if strings.Contains(modelID, "/") {
|
||||
parts := strings.Split(modelID, "/")
|
||||
modelID = parts[len(parts)-1]
|
||||
}
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if !strings.HasPrefix(strings.ToLower(modelID), "gpt-") {
|
||||
return trimmed, "", false
|
||||
}
|
||||
|
||||
parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool {
|
||||
switch r {
|
||||
case '-', '_', ' ':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
if len(parts) == 0 {
|
||||
return trimmed, "", false
|
||||
}
|
||||
|
||||
last := strings.NewReplacer("-", "", "_", "", " ", "").Replace(parts[len(parts)-1])
|
||||
switch last {
|
||||
case "none", "minimal":
|
||||
case "low", "medium", "high":
|
||||
reasoningEffort = last
|
||||
case "xhigh", "extrahigh":
|
||||
reasoningEffort = "xhigh"
|
||||
default:
|
||||
return trimmed, "", false
|
||||
}
|
||||
|
||||
return normalizeCodexModel(modelID), reasoningEffort, true
|
||||
}
|
||||
|
||||
// openAIReasoningEffortToClaudeOutputEffort maps an OpenAI-style reasoning
// effort to the value written into the Anthropic request's output-config
// effort field.
//
// Surrounding whitespace is ignored. "low", "medium" and "high" map to
// themselves, "xhigh" maps to "max", and every other value (including the
// empty string) maps to "" to signal that no effort should be set.
func openAIReasoningEffortToClaudeOutputEffort(effort string) string {
	switch trimmed := strings.TrimSpace(effort); trimmed {
	case "low", "medium", "high":
		// Return the trimmed value so padding in the input never leaks into
		// the outgoing request.
		return trimmed
	case "xhigh":
		return "max"
	default:
		return ""
	}
}
|
||||
129
backend/internal/service/openai_compat_model_test.go
Normal file
129
backend/internal/service/openai_compat_model_test.go
Normal file
@ -0,0 +1,129 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeOpenAICompatRequestedModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{name: "gpt reasoning alias strips xhigh", input: "gpt-5.4-xhigh", want: "gpt-5.4"},
|
||||
{name: "gpt reasoning alias strips none", input: "gpt-5.4-none", want: "gpt-5.4"},
|
||||
{name: "codex max model stays intact", input: "gpt-5.1-codex-max", want: "gpt-5.1-codex-max"},
|
||||
{name: "non openai model unchanged", input: "claude-opus-4-6", want: "claude-opus-4-6"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, NormalizeOpenAICompatRequestedModel(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOpenAICompatModelNormalization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("derives xhigh from model suffix when output config missing", func(t *testing.T) {
|
||||
req := &apicompat.AnthropicRequest{Model: "gpt-5.4-xhigh"}
|
||||
|
||||
applyOpenAICompatModelNormalization(req)
|
||||
|
||||
require.Equal(t, "gpt-5.4", req.Model)
|
||||
require.NotNil(t, req.OutputConfig)
|
||||
require.Equal(t, "max", req.OutputConfig.Effort)
|
||||
})
|
||||
|
||||
t.Run("explicit output config wins over model suffix", func(t *testing.T) {
|
||||
req := &apicompat.AnthropicRequest{
|
||||
Model: "gpt-5.4-xhigh",
|
||||
OutputConfig: &apicompat.AnthropicOutputConfig{Effort: "low"},
|
||||
}
|
||||
|
||||
applyOpenAICompatModelNormalization(req)
|
||||
|
||||
require.Equal(t, "gpt-5.4", req.Model)
|
||||
require.NotNil(t, req.OutputConfig)
|
||||
require.Equal(t, "low", req.OutputConfig.Effort)
|
||||
})
|
||||
|
||||
t.Run("non openai model is untouched", func(t *testing.T) {
|
||||
req := &apicompat.AnthropicRequest{Model: "claude-opus-4-6"}
|
||||
|
||||
applyOpenAICompatModelNormalization(req)
|
||||
|
||||
require.Equal(t, "claude-opus-4-6", req.Model)
|
||||
require.Nil(t, req.OutputConfig)
|
||||
})
|
||||
}
|
||||
|
||||
func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := []byte(`{"model":"gpt-5.4-xhigh","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_compat"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token",
|
||||
"chatgpt_account_id": "chatgpt-acc",
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "gpt-5.4-xhigh", result.Model)
|
||||
require.Equal(t, "gpt-5.4", result.UpstreamModel)
|
||||
require.Equal(t, "gpt-5.4", result.BillingModel)
|
||||
require.NotNil(t, result.ReasoningEffort)
|
||||
require.Equal(t, "xhigh", *result.ReasoningEffort)
|
||||
|
||||
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, "xhigh", gjson.GetBytes(upstream.lastBody, "reasoning.effort").String())
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "gpt-5.4-xhigh", gjson.GetBytes(rec.Body.Bytes(), "model").String())
|
||||
require.Equal(t, "ok", gjson.GetBytes(rec.Body.Bytes(), "content.0.text").String())
|
||||
t.Logf("upstream body: %s", string(upstream.lastBody))
|
||||
t.Logf("response body: %s", rec.Body.String())
|
||||
}
|
||||
@ -40,6 +40,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
return nil, fmt.Errorf("parse anthropic request: %w", err)
|
||||
}
|
||||
originalModel := anthropicReq.Model
|
||||
applyOpenAICompatModelNormalization(&anthropicReq)
|
||||
clientStream := anthropicReq.Stream // client's original stream preference
|
||||
|
||||
// 2. Convert Anthropic → Responses
|
||||
@ -59,7 +60,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
}
|
||||
|
||||
// 3. Model mapping
|
||||
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||
mappedModel := resolveOpenAIForwardModel(account, anthropicReq.Model, defaultMappedModel)
|
||||
responsesReq.Model = mappedModel
|
||||
|
||||
logger.L().Debug("openai messages: model mapping applied",
|
||||
|
||||
@ -895,14 +895,16 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) {
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingRequestedModel(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{
|
||||
// Billing should use the requested model ("gpt-5.1"), not the upstream mapped model ("gpt-5.1-codex").
|
||||
// This ensures pricing is always based on the model the user requested.
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
|
||||
@ -4153,9 +4153,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
}
|
||||
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
if result.BillingModel != "" {
|
||||
billingModel = strings.TrimSpace(result.BillingModel)
|
||||
}
|
||||
serviceTier := ""
|
||||
if result.ServiceTier != nil {
|
||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||
|
||||
@ -21,8 +21,8 @@ func optionalNonEqualStringPtr(value, compare string) *string {
|
||||
}
|
||||
|
||||
func forwardResultBillingModel(requestedModel, upstreamModel string) string {
|
||||
if trimmedUpstream := strings.TrimSpace(upstreamModel); trimmedUpstream != "" {
|
||||
return trimmedUpstream
|
||||
if trimmed := strings.TrimSpace(requestedModel); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
return strings.TrimSpace(requestedModel)
|
||||
return strings.TrimSpace(upstreamModel)
|
||||
}
|
||||
|
||||
@ -4383,6 +4383,7 @@ export default {
|
||||
provider: 'Type',
|
||||
active: 'Active',
|
||||
endpoint: 'Endpoint',
|
||||
bucket: 'Bucket',
|
||||
storagePath: 'Storage Path',
|
||||
capacityUsage: 'Capacity / Used',
|
||||
capacityUnlimited: 'Unlimited',
|
||||
|
||||
@ -4547,6 +4547,7 @@ export default {
|
||||
provider: '存储类型',
|
||||
active: '生效状态',
|
||||
endpoint: '端点',
|
||||
bucket: '存储桶',
|
||||
storagePath: '存储路径',
|
||||
capacityUsage: '容量 / 已用',
|
||||
capacityUnlimited: '无限制',
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
import { computed, onMounted, reactive, ref, watch } from 'vue'
|
||||
import { opsAPI, type OpsRuntimeLogConfig, type OpsSystemLog, type OpsSystemLogSinkHealth } from '@/api/admin/ops'
|
||||
import Pagination from '@/components/common/Pagination.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import { useAppStore } from '@/stores'
|
||||
|
||||
const appStore = useAppStore()
|
||||
@ -56,6 +57,37 @@ const filters = reactive({
|
||||
q: ''
|
||||
})
|
||||
|
||||
const runtimeLevelOptions = [
|
||||
{ value: 'debug', label: 'debug' },
|
||||
{ value: 'info', label: 'info' },
|
||||
{ value: 'warn', label: 'warn' },
|
||||
{ value: 'error', label: 'error' }
|
||||
]
|
||||
|
||||
const stacktraceLevelOptions = [
|
||||
{ value: 'none', label: 'none' },
|
||||
{ value: 'error', label: 'error' },
|
||||
{ value: 'fatal', label: 'fatal' }
|
||||
]
|
||||
|
||||
const timeRangeOptions = [
|
||||
{ value: '5m', label: '5m' },
|
||||
{ value: '30m', label: '30m' },
|
||||
{ value: '1h', label: '1h' },
|
||||
{ value: '6h', label: '6h' },
|
||||
{ value: '24h', label: '24h' },
|
||||
{ value: '7d', label: '7d' },
|
||||
{ value: '30d', label: '30d' }
|
||||
]
|
||||
|
||||
const filterLevelOptions = [
|
||||
{ value: '', label: '全部' },
|
||||
{ value: 'debug', label: 'debug' },
|
||||
{ value: 'info', label: 'info' },
|
||||
{ value: 'warn', label: 'warn' },
|
||||
{ value: 'error', label: 'error' }
|
||||
]
|
||||
|
||||
const levelBadgeClass = (level: string) => {
|
||||
const v = String(level || '').toLowerCase()
|
||||
if (v === 'error' || v === 'fatal') return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300'
|
||||
@ -347,20 +379,11 @@ onMounted(async () => {
|
||||
<div class="grid grid-cols-1 gap-3 md:grid-cols-2 xl:grid-cols-6">
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
级别
|
||||
<select v-model="runtimeConfig.level" class="input mt-1">
|
||||
<option value="debug">debug</option>
|
||||
<option value="info">info</option>
|
||||
<option value="warn">warn</option>
|
||||
<option value="error">error</option>
|
||||
</select>
|
||||
<Select v-model="runtimeConfig.level" class="mt-1" :options="runtimeLevelOptions" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
堆栈阈值
|
||||
<select v-model="runtimeConfig.stacktrace_level" class="input mt-1">
|
||||
<option value="none">none</option>
|
||||
<option value="error">error</option>
|
||||
<option value="fatal">fatal</option>
|
||||
</select>
|
||||
<Select v-model="runtimeConfig.stacktrace_level" class="mt-1" :options="stacktraceLevelOptions" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
采样初始
|
||||
@ -403,15 +426,7 @@ onMounted(async () => {
|
||||
<div class="mb-4 grid grid-cols-1 gap-3 md:grid-cols-5">
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
时间范围
|
||||
<select v-model="filters.time_range" class="input mt-1">
|
||||
<option value="5m">5m</option>
|
||||
<option value="30m">30m</option>
|
||||
<option value="1h">1h</option>
|
||||
<option value="6h">6h</option>
|
||||
<option value="24h">24h</option>
|
||||
<option value="7d">7d</option>
|
||||
<option value="30d">30d</option>
|
||||
</select>
|
||||
<Select v-model="filters.time_range" class="mt-1" :options="timeRangeOptions" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
开始时间(可选)
|
||||
@ -423,13 +438,7 @@ onMounted(async () => {
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
级别
|
||||
<select v-model="filters.level" class="input mt-1">
|
||||
<option value="">全部</option>
|
||||
<option value="debug">debug</option>
|
||||
<option value="info">info</option>
|
||||
<option value="warn">warn</option>
|
||||
<option value="error">error</option>
|
||||
</select>
|
||||
<Select v-model="filters.level" class="mt-1" :options="filterLevelOptions" />
|
||||
</label>
|
||||
<label class="text-xs text-gray-600 dark:text-gray-300">
|
||||
组件
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user