修复h1
This commit is contained in:
parent
1e19d9caca
commit
b856586412
@ -127,9 +127,9 @@ COPY --chown=sub2api:sub2api deploy/ls-bin/language_server_linux_* /tmp/ls-bin/
|
||||
COPY --chown=sub2api:sub2api deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/
|
||||
RUN mkdir -p /app/ls/extensions/antigravity/bin && \
|
||||
if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \
|
||||
cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \
|
||||
else \
|
||||
cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \
|
||||
cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \
|
||||
fi && \
|
||||
chmod +x /app/ls/extensions/antigravity/bin/language_server_linux_* && \
|
||||
rm -rf /tmp/ls-bin
|
||||
|
||||
@ -681,10 +681,12 @@ async function proxyRequest(req, res) {
|
||||
})();
|
||||
await new Promise(r => setTimeout(r, jitterMs));
|
||||
|
||||
// ── H2 优先策略 ──────────────────────────────────────────────────
|
||||
// Anthropic/Google API 均支持 HTTP/2。
|
||||
// 直接走 H2 = Node.js 原生帧顺序,与真实 CLI 完全一致。
|
||||
// 其他 host 维持原有 H1→H2 自动切换逻辑。
|
||||
// ── H2 / H1 路由策略 ──────────────────────────────────────────────
|
||||
// 当存在 per-account 上游代理(X-Upstream-Proxy)时,强制走 H1:
|
||||
// 1. H2 的 getOrCreateH2Session 不支持 CONNECT 隧道代理
|
||||
// 2. 真实 CLI 用 undici 默认的 HTTP/1.1(allowH2=false),H1 更贴合指纹
|
||||
// 无代理的直连请求仍可走 H2 以获得多路复用性能。
|
||||
const hasUpstreamProxy = !!(req.headers['x-upstream-proxy'] || UPSTREAM_PROXY);
|
||||
const H2_PREFER_HOSTS = new Set([
|
||||
'api.anthropic.com',
|
||||
'cloudaicompanion.googleapis.com',
|
||||
@ -692,9 +694,12 @@ async function proxyRequest(req, res) {
|
||||
'cloudcode-pa.googleapis.com',
|
||||
'daily-cloudcode-pa.googleapis.com',
|
||||
]);
|
||||
if (H2_PREFER_HOSTS.has(targetHost) || h2Hosts.has(targetHost)) {
|
||||
if (!hasUpstreamProxy && (H2_PREFER_HOSTS.has(targetHost) || h2Hosts.has(targetHost))) {
|
||||
await sendViaH2(targetHost, req.method, req.url, req.headers, body, res, savedHeaders);
|
||||
} else {
|
||||
if (hasUpstreamProxy && H2_PREFER_HOSTS.has(targetHost)) {
|
||||
log('info', 'h2_downgrade_to_h1', { host: targetHost, reason: 'upstream_proxy_set' });
|
||||
}
|
||||
await sendViaH1(targetHost, req.method, req.url, req.headers, body, res, savedHeaders);
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,6 +79,7 @@ func provideCleanup(
|
||||
soraMediaCleanup *service.SoraMediaCleanupService,
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
lsPoolBootstrap *service.LSPoolBootstrapService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
subscriptionExpiry *service.SubscriptionExpiryService,
|
||||
usageCleanup *service.UsageCleanupService,
|
||||
@ -171,6 +172,12 @@ func provideCleanup(
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"LSPoolBootstrapService", func() error {
|
||||
if lsPoolBootstrap != nil {
|
||||
lsPoolBootstrap.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"AccountExpiryService", func() error {
|
||||
accountExpiry.Stop()
|
||||
return nil
|
||||
|
||||
@ -246,10 +246,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||
lsPoolBootstrapService := service.ProvideLSPoolBootstrapService(accountRepository, configConfig)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, lsPoolBootstrapService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@ -287,6 +288,7 @@ func provideCleanup(
|
||||
soraMediaCleanup *service.SoraMediaCleanupService,
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
lsPoolBootstrap *service.LSPoolBootstrapService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
subscriptionExpiry *service.SubscriptionExpiryService,
|
||||
usageCleanup *service.UsageCleanupService,
|
||||
@ -378,6 +380,12 @@ func provideCleanup(
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"LSPoolBootstrapService", func() error {
|
||||
if lsPoolBootstrap != nil {
|
||||
lsPoolBootstrap.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"AccountExpiryService", func() error {
|
||||
accountExpiry.Stop()
|
||||
return nil
|
||||
|
||||
@ -47,6 +47,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
|
||||
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
|
||||
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
|
||||
lsPoolBootstrapSvc := service.NewLSPoolBootstrapService(nil, nil, cfg)
|
||||
|
||||
cleanup := provideCleanup(
|
||||
nil, // entClient
|
||||
@ -60,6 +61,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
&service.SoraMediaCleanupService{},
|
||||
schedulerSnapshotSvc,
|
||||
tokenRefreshSvc,
|
||||
lsPoolBootstrapSvc,
|
||||
accountExpirySvc,
|
||||
subscriptionExpirySvc,
|
||||
&service.UsageCleanupService{},
|
||||
|
||||
@ -515,7 +515,7 @@ func validateDataProxy(item DataProxy) error {
|
||||
return errors.New("proxy port is invalid")
|
||||
}
|
||||
switch item.Protocol {
|
||||
case "http", "https", "socks5", "socks5h":
|
||||
case "http", "socks5", "socks5h":
|
||||
default:
|
||||
return fmt.Errorf("proxy protocol is invalid: %s", item.Protocol)
|
||||
}
|
||||
|
||||
@ -27,7 +27,7 @@ func NewProxyHandler(adminService service.AdminService) *ProxyHandler {
|
||||
// CreateProxyRequest represents create proxy request
|
||||
type CreateProxyRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http socks5 socks5h"`
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
@ -37,7 +37,7 @@ type CreateProxyRequest struct {
|
||||
// UpdateProxyRequest represents update proxy request
|
||||
type UpdateProxyRequest struct {
|
||||
Name string `json:"name"`
|
||||
Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5 socks5h"`
|
||||
Protocol string `json:"protocol" binding:"omitempty,oneof=http socks5 socks5h"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port" binding:"omitempty,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
@ -299,7 +299,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
||||
|
||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||
type BatchCreateProxyItem struct {
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http socks5 socks5h"`
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
|
||||
@ -53,8 +53,8 @@ const (
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.107.0
|
||||
var defaultUserAgentVersion = "1.107.0"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
// defaultClientSecret 必须通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret string
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取版本号,未设置则使用默认值
|
||||
@ -73,6 +73,10 @@ func GetUserAgent() string {
|
||||
}
|
||||
|
||||
func getClientSecret() (string, error) {
|
||||
if secret := strings.TrimSpace(os.Getenv(AntigravityOAuthClientSecretEnv)); secret != "" {
|
||||
defaultClientSecret = secret
|
||||
return secret, nil
|
||||
}
|
||||
if v := strings.TrimSpace(defaultClientSecret); v != "" {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
19
backend/internal/pkg/antigravity/oauth_runtime_env_test.go
Normal file
19
backend/internal/pkg/antigravity/oauth_runtime_env_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package antigravity
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetClientSecret_ReadsRuntimeEnvironment(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "runtime-secret")
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("getClientSecret returned error: %v", err)
|
||||
}
|
||||
if secret != "runtime-secret" {
|
||||
t.Fatalf("unexpected secret: got %q want %q", secret, "runtime-secret")
|
||||
}
|
||||
}
|
||||
@ -35,13 +35,11 @@ const (
|
||||
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
|
||||
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"
|
||||
|
||||
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
|
||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||
// restrict which scopes are allowed for this client.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
GeminiCLIOAuthClientSecret = "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"
|
||||
// GeminiCLIOAuthClientID is the public OAuth client ID used by Google Gemini CLI.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
|
||||
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
|
||||
// The secret MUST be provided via this env var — no hardcoded fallback.
|
||||
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
|
||||
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
@ -170,11 +170,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
|
||||
// Fall back to built-in Gemini CLI OAuth client when not configured.
|
||||
// SECURITY: This repo does not embed the built-in client secret; it must be provided via env.
|
||||
if effective.ClientID == "" && effective.ClientSecret == "" {
|
||||
secret := strings.TrimSpace(GeminiCLIOAuthClientSecret)
|
||||
if secret == "" {
|
||||
if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok {
|
||||
secret = strings.TrimSpace(v)
|
||||
}
|
||||
var secret string
|
||||
if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok {
|
||||
secret = strings.TrimSpace(v)
|
||||
}
|
||||
if secret == "" {
|
||||
return OAuthConfig{}, infraerrors.Newf(http.StatusBadRequest, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING", "built-in Gemini CLI OAuth client_secret is not configured; set %s or provide a custom OAuth client", GeminiCLIOAuthClientSecretEnv)
|
||||
|
||||
@ -408,10 +408,10 @@ func TestBuildAuthorizationURL_WithProjectID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_UsesBuiltinSecretFallback(t *testing.T) {
|
||||
func TestBuildAuthorizationURL_RequiresBuiltinSecretEnv(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
authURL, err := BuildAuthorizationURL(
|
||||
_, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
@ -419,11 +419,11 @@ func TestBuildAuthorizationURL_UsesBuiltinSecretFallback(t *testing.T) {
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildAuthorizationURL() 不应报错: %v", err)
|
||||
if err == nil {
|
||||
t.Fatal("BuildAuthorizationURL() 应在未配置内置 secret 环境变量时报错")
|
||||
}
|
||||
if !strings.Contains(authURL, "client_id="+GeminiCLIOAuthClientID) {
|
||||
t.Errorf("应使用内置 Gemini CLI client_id,实际 URL: %s", authURL)
|
||||
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
|
||||
t.Fatalf("错误消息应提示缺少 %s: %v", GeminiCLIOAuthClientSecretEnv, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -686,18 +686,15 @@ func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
|
||||
func TestEffectiveOAuthConfig_RequiresEnvSecret(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("不设置环境变量时应回退到内置 secret,实际报错: %v", err)
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
|
||||
if err == nil {
|
||||
t.Fatal("未配置环境变量时应报错,而不是回退到仓库内置 secret")
|
||||
}
|
||||
if strings.TrimSpace(cfg.ClientSecret) == "" {
|
||||
t.Error("ClientSecret 不应为空")
|
||||
}
|
||||
if cfg.ClientID != GeminiCLIOAuthClientID {
|
||||
t.Errorf("ClientID 应回退为内置客户端 ID,实际: %q", cfg.ClientID)
|
||||
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
|
||||
t.Fatalf("错误消息应提示缺少 %s: %v", GeminiCLIOAuthClientSecretEnv, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -45,6 +45,8 @@ import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
@ -137,9 +139,11 @@ type Instance struct {
|
||||
startedAt time.Time
|
||||
inflight int64 // atomic: current number of concurrent cascade calls
|
||||
modelMapReady int32
|
||||
modelMapHard int32
|
||||
remote bool
|
||||
workerToken string
|
||||
routingKey string
|
||||
modelMapError string
|
||||
}
|
||||
|
||||
// AcquireConcurrency atomically increments the inflight counter.
|
||||
@ -176,6 +180,38 @@ func (i *Instance) SetModelMappingReady(ready bool) {
|
||||
atomic.StoreInt32(&i.modelMapReady, 0)
|
||||
}
|
||||
|
||||
// SetModelMappingUnavailable marks the instance as unable to load model config
|
||||
// with the current token/client combination.
|
||||
func (i *Instance) SetModelMappingUnavailable(reason string) {
|
||||
atomic.StoreInt32(&i.modelMapHard, 1)
|
||||
i.mu.Lock()
|
||||
i.modelMapError = strings.TrimSpace(reason)
|
||||
i.mu.Unlock()
|
||||
}
|
||||
|
||||
// ClearModelMappingUnavailable resets any previously recorded permanent model
|
||||
// mapping failure state.
|
||||
func (i *Instance) ClearModelMappingUnavailable() {
|
||||
atomic.StoreInt32(&i.modelMapHard, 0)
|
||||
i.mu.Lock()
|
||||
i.modelMapError = ""
|
||||
i.mu.Unlock()
|
||||
}
|
||||
|
||||
// HasModelMappingUnavailable reports whether model config loading is currently
|
||||
// known to be incompatible with the account/token.
|
||||
func (i *Instance) HasModelMappingUnavailable() bool {
|
||||
return atomic.LoadInt32(&i.modelMapHard) == 1
|
||||
}
|
||||
|
||||
// ModelMappingUnavailableReason returns the last recorded permanent failure
|
||||
// reason, if any.
|
||||
func (i *Instance) ModelMappingUnavailableReason() string {
|
||||
i.mu.RLock()
|
||||
defer i.mu.RUnlock()
|
||||
return strings.TrimSpace(i.modelMapError)
|
||||
}
|
||||
|
||||
// HasModelMappingReady reports whether this LS instance has completed model
|
||||
// config loading successfully.
|
||||
func (i *Instance) HasModelMappingReady() bool {
|
||||
@ -630,6 +666,16 @@ func (p *Pool) SetAccountToken(accountID, accessToken, refreshToken string, expi
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
}
|
||||
p.mu.RLock()
|
||||
slots := append([]*Instance(nil), p.instances[accountID]...)
|
||||
p.mu.RUnlock()
|
||||
for _, inst := range slots {
|
||||
if inst == nil {
|
||||
continue
|
||||
}
|
||||
inst.SetModelMappingReady(false)
|
||||
inst.ClearModelMappingUnavailable()
|
||||
}
|
||||
}
|
||||
|
||||
// SetAccountModelCredits updates the JS-parity uss-modelCredits state for an account.
|
||||
@ -735,9 +781,9 @@ func (p *Pool) startInstance(accountID string, proxyURL string, replica int) (*I
|
||||
p.logger.Info("LS starting",
|
||||
"account", shortAccountID(accountID),
|
||||
"replica", replica,
|
||||
"proxy_source", rawProxyURL,
|
||||
"proxy_source", logredact.RedactProxyURL(rawProxyURL),
|
||||
"proxy_mode", launchPlan.proxyMode,
|
||||
"effective_proxy", launchPlan.effectiveProxyURL)
|
||||
"effective_proxy", logredact.RedactProxyURL(launchPlan.effectiveProxyURL))
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
@ -849,6 +895,14 @@ func (p *Pool) startInstance(accountID string, proxyURL string, replica int) (*I
|
||||
p.logger.Info("model mapping loaded", "account", shortAccountID(accountID), "replica", replica, "attempt", attempt)
|
||||
return
|
||||
}
|
||||
if inst.HasModelMappingUnavailable() {
|
||||
p.logger.Warn("model mapping unavailable",
|
||||
"account", shortAccountID(accountID),
|
||||
"replica", replica,
|
||||
"attempt", attempt,
|
||||
"reason", truncate(inst.ModelMappingUnavailableReason(), 200))
|
||||
return
|
||||
}
|
||||
p.logger.Warn("model mapping not loaded, retrying", "account", shortAccountID(accountID), "replica", replica, "attempt", attempt)
|
||||
time.Sleep(time.Duration(attempt*10) * time.Second)
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
@ -197,6 +198,35 @@ func TestCurrentLSStrategy(t *testing.T) {
|
||||
require.Equal(t, LSStrategyDirect, CurrentLSStrategy())
|
||||
}
|
||||
|
||||
func TestIsPermanentModelMappingError(t *testing.T) {
|
||||
require.True(t, isPermanentModelMappingError(errors.New(`oauth2: "unauthorized_client" "Unauthorized"`)))
|
||||
require.False(t, isPermanentModelMappingError(errors.New("context deadline exceeded")))
|
||||
}
|
||||
|
||||
func TestPoolSetAccountTokenClearsModelMappingUnavailable(t *testing.T) {
|
||||
pool := &Pool{
|
||||
instances: map[string][]*Instance{
|
||||
"9": {
|
||||
{AccountID: "9", Replica: 0},
|
||||
},
|
||||
},
|
||||
}
|
||||
inst := pool.instances["9"][0]
|
||||
inst.SetModelMappingReady(true)
|
||||
inst.SetModelMappingUnavailable(`oauth2: "unauthorized_client" "Unauthorized"`)
|
||||
|
||||
pool.SetAccountToken("9", "ya29.new", "refresh", time.Now().Add(time.Hour))
|
||||
|
||||
require.False(t, inst.HasModelMappingReady())
|
||||
require.False(t, inst.HasModelMappingUnavailable())
|
||||
require.Empty(t, inst.ModelMappingUnavailableReason())
|
||||
}
|
||||
|
||||
func TestShouldFallbackDirectForModelMappingUnavailable(t *testing.T) {
|
||||
require.True(t, shouldFallbackDirect(fmt.Errorf("%w: oauth2 unauthorized_client", errLSModelMapDenied)))
|
||||
require.False(t, shouldFallbackDirect(errLSModelMapPending))
|
||||
}
|
||||
|
||||
func TestParseLSReplicaCountDefaultAndEnv(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "")
|
||||
require.Equal(t, 5, parseLSReplicaCount())
|
||||
|
||||
@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestBuildProxychainsConfigIncludesAuthAndLocalBypass(t *testing.T) {
|
||||
proxyURL, err := url.Parse("socks5h://gostuser:fastapipwd@216.167.85.31:1080")
|
||||
proxyURL, err := url.Parse("socks5h://testuser:testpass@192.0.2.1:1080")
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := buildProxychainsConfig(proxyURL)
|
||||
@ -18,7 +18,7 @@ func TestBuildProxychainsConfigIncludesAuthAndLocalBypass(t *testing.T) {
|
||||
require.Contains(t, cfg, "localnet 127.0.0.0/255.0.0.0\n")
|
||||
require.Contains(t, cfg, "localnet ::1/128\n")
|
||||
require.Contains(t, cfg, "[ProxyList]\n")
|
||||
require.Contains(t, cfg, "socks5 216.167.85.31 1080 gostuser fastapipwd\n")
|
||||
require.Contains(t, cfg, "socks5 192.0.2.1 1080 testuser testpass\n")
|
||||
}
|
||||
|
||||
func TestBuildProxychainsConfigRejectsWhitespaceCredentials(t *testing.T) {
|
||||
|
||||
@ -71,6 +71,7 @@ var (
|
||||
errLSTranscriptDrift = errors.New("request transcript diverged from cached cascade session")
|
||||
errLSQuotaExhausted = errors.New("ls cascade returned quota exhausted")
|
||||
errLSModelMapPending = errors.New("model mapping not ready")
|
||||
errLSModelMapDenied = errors.New("model mapping unavailable")
|
||||
)
|
||||
|
||||
// IsLSQuotaExhaustedError reports whether err originated from an LS cascade
|
||||
@ -98,6 +99,20 @@ func LSQuotaExhaustedMessage(err error) string {
|
||||
return msg
|
||||
}
|
||||
|
||||
func isPermanentModelMappingError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(err.Error()), "unauthorized_client")
|
||||
}
|
||||
|
||||
func modelMappingDeniedReason(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return truncate(strings.TrimSpace(err.Error()), 200)
|
||||
}
|
||||
|
||||
type cascadeSessionState struct {
|
||||
CascadeID string
|
||||
SystemText string
|
||||
@ -269,7 +284,9 @@ func (u *LSPoolUpstream) doViaLS(req *http.Request, body []byte, accountID int64
|
||||
}
|
||||
|
||||
func shouldFallbackDirect(err error) bool {
|
||||
return errors.Is(err, errLSRouteDirect) || errors.Is(err, errLSTranscriptDrift)
|
||||
return errors.Is(err, errLSRouteDirect) ||
|
||||
errors.Is(err, errLSTranscriptDrift) ||
|
||||
errors.Is(err, errLSModelMapDenied)
|
||||
}
|
||||
|
||||
func (u *LSPoolUpstream) forwardDirectWithKeepalive(req *http.Request, body []byte, accountKey string, accountID int64, proxyURL string) (*http.Response, error) {
|
||||
@ -413,6 +430,11 @@ func (u *LSPoolUpstream) forwardChatViaLS(req *http.Request, body []byte, parsed
|
||||
}
|
||||
trace.GetOrCreateDuration = time.Since(getOrCreateStartedAt)
|
||||
trace.Replica = inst.Replica
|
||||
if inst.HasModelMappingUnavailable() {
|
||||
reason := inst.ModelMappingUnavailableReason()
|
||||
u.logTraceSummary(slog.LevelInfo, "[LS-POOL] model mapping unavailable, routing direct", trace, "reason", reason)
|
||||
return nil, fmt.Errorf("%w: %s", errLSModelMapDenied, reason)
|
||||
}
|
||||
if !inst.HasModelMappingReady() {
|
||||
u.logTraceSummary(slog.LevelInfo, "[LS-POOL] model mapping pending, routing direct", trace)
|
||||
return nil, errLSModelMapPending
|
||||
@ -1391,6 +1413,18 @@ func RefreshModelMapping(inst *Instance) bool {
|
||||
resp, err := inst.CallUnaryJSON(ctx, LSService, "GetCascadeModelConfigData", map[string]any{})
|
||||
if err != nil {
|
||||
inst.SetModelMappingReady(false)
|
||||
if isPermanentModelMappingError(err) {
|
||||
reason := modelMappingDeniedReason(err)
|
||||
inst.SetModelMappingUnavailable(reason)
|
||||
slog.Warn("[LS-POOL] Model mapping unavailable",
|
||||
"account", inst.AccountID,
|
||||
"replica", inst.Replica,
|
||||
"address", inst.Address,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
|
||||
"reason", reason)
|
||||
return false
|
||||
}
|
||||
inst.ClearModelMappingUnavailable()
|
||||
slog.Warn("[LS-POOL] Failed to get model config",
|
||||
"account", inst.AccountID,
|
||||
"replica", inst.Replica,
|
||||
@ -1408,6 +1442,7 @@ func RefreshModelMapping(inst *Instance) bool {
|
||||
}
|
||||
if err := json.Unmarshal(resp, &data); err != nil {
|
||||
inst.SetModelMappingReady(false)
|
||||
inst.ClearModelMappingUnavailable()
|
||||
return false
|
||||
}
|
||||
|
||||
@ -1440,6 +1475,7 @@ func RefreshModelMapping(inst *Instance) bool {
|
||||
dynamicModelMap = newMap
|
||||
dynamicModelMapMu.Unlock()
|
||||
inst.SetModelMappingReady(true)
|
||||
inst.ClearModelMappingUnavailable()
|
||||
slog.Info("[LS-POOL] Model mapping refreshed",
|
||||
"account", inst.AccountID,
|
||||
"replica", inst.Replica,
|
||||
@ -1449,6 +1485,7 @@ func RefreshModelMapping(inst *Instance) bool {
|
||||
return true
|
||||
}
|
||||
inst.SetModelMappingReady(false)
|
||||
inst.ClearModelMappingUnavailable()
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@ -501,6 +501,11 @@ func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey stri
|
||||
values.Set("routing_key", routingKey)
|
||||
}
|
||||
|
||||
var (
|
||||
lastStatus int
|
||||
lastBody string
|
||||
)
|
||||
|
||||
for {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/readyz", values), nil)
|
||||
if err != nil {
|
||||
@ -511,22 +516,48 @@ func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey stri
|
||||
if err == nil {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
lastStatus = resp.StatusCode
|
||||
lastBody = truncate(string(body), 200)
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
if len(body) > 0 {
|
||||
if isWorkerModelMappingUnavailable(resp.StatusCode, lastBody) {
|
||||
return fmt.Errorf("%w: worker %s %s", errLSModelMapDenied, handle.Container, strings.TrimSpace(lastBody))
|
||||
}
|
||||
if len(body) > 0 && shouldWarnWorkerNotReady(resp.StatusCode, lastBody) {
|
||||
m.logger.Warn("ls worker not ready yet", "container", handle.Container, "status", resp.StatusCode, "body", truncate(string(body), 200))
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if lastStatus > 0 || lastBody != "" {
|
||||
return fmt.Errorf("worker %s not ready for routing key %q (last_status=%d last_body=%q): %w", handle.Container, routingKey, lastStatus, lastBody, ctx.Err())
|
||||
}
|
||||
return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, ctx.Err())
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shouldWarnWorkerNotReady(status int, body string) bool {
|
||||
if status == http.StatusServiceUnavailable {
|
||||
normalized := strings.ToLower(strings.TrimSpace(body))
|
||||
if strings.Contains(normalized, "model mapping not ready") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isWorkerModelMappingUnavailable(status int, body string) bool {
|
||||
if status != http.StatusServiceUnavailable {
|
||||
return false
|
||||
}
|
||||
normalized := strings.ToLower(strings.TrimSpace(body))
|
||||
return strings.Contains(normalized, errLSModelMapDenied.Error())
|
||||
}
|
||||
|
||||
func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error {
|
||||
if state == nil {
|
||||
return fmt.Errorf("ls worker state is nil")
|
||||
|
||||
@ -264,3 +264,72 @@ func TestFakeDockerClientImplementsFilterAwareList(t *testing.T) {
|
||||
_, err := fakeDocker.ContainerList(context.Background(), container.ListOptions{Filters: filters.NewArgs()})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestShouldWarnWorkerNotReadySuppressesModelMappingPending(t *testing.T) {
|
||||
require.False(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker model mapping not ready for replica 0"))
|
||||
require.True(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker access token not configured"))
|
||||
require.True(t, shouldWarnWorkerNotReady(http.StatusBadGateway, "upstream failed"))
|
||||
}
|
||||
|
||||
func TestWorkerManagerWaitForWorkerReadyStopsOnModelMappingUnavailable(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/readyz", r.URL.Path)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = w.Write([]byte(`model mapping unavailable for replica 0: oauth2: "unauthorized_client" "Unauthorized"`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 1,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, &fakeDockerClient{})
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
handle := &workerHandle{
|
||||
Container: "sub2api-ls-test",
|
||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
||||
AuthToken: "worker-token",
|
||||
}
|
||||
|
||||
err = manager.waitForWorkerReady(handle, "")
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, errLSModelMapDenied)
|
||||
}
|
||||
|
||||
func TestWorkerManagerWaitForWorkerReadyIncludesLastBodyOnTimeout(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/readyz", r.URL.Path)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = w.Write([]byte("worker model mapping not ready for replica 0\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 1,
|
||||
StartupTimeout: 100 * time.Millisecond,
|
||||
RequestTimeout: time.Second,
|
||||
}, &fakeDockerClient{})
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
handle := &workerHandle{
|
||||
Container: "sub2api-ls-test",
|
||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
||||
AuthToken: "worker-token",
|
||||
}
|
||||
|
||||
err = manager.waitForWorkerReady(handle, "")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), `last_status=503`)
|
||||
require.Contains(t, err.Error(), `last_body="worker model mapping not ready for replica 0`)
|
||||
}
|
||||
|
||||
@ -308,6 +308,9 @@ func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Ins
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if inst.HasModelMappingUnavailable() {
|
||||
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
|
||||
}
|
||||
if inst.HasModelMappingReady() {
|
||||
return inst, nil
|
||||
}
|
||||
@ -316,6 +319,9 @@ func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Ins
|
||||
defer cancel()
|
||||
_ = modelCtx
|
||||
if !RefreshModelMapping(inst) {
|
||||
if inst.HasModelMappingUnavailable() {
|
||||
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
|
||||
}
|
||||
return nil, fmt.Errorf("worker model mapping not ready for replica %d", inst.Replica)
|
||||
}
|
||||
return inst, nil
|
||||
|
||||
@ -13,9 +13,12 @@ import (
|
||||
)
|
||||
|
||||
// allowedSchemes 代理协议白名单
|
||||
// 注意: https 代理已被移除。当前实现(Go dialer.go 和 Node proxy.js)
|
||||
// 对 https:// 代理仅做 TCP 连接后发明文 CONNECT,不建立外层 TLS,
|
||||
// 导致 Proxy-Authorization 凭据在首跳明文传输。
|
||||
// 若需 https 代理支持,须先在 dialer.go 和 proxy.js 中实现 TLS-to-proxy。
|
||||
var allowedSchemes = map[string]bool{
|
||||
"http": true,
|
||||
"https": true,
|
||||
"socks5": true,
|
||||
"socks5h": true,
|
||||
}
|
||||
@ -31,7 +34,7 @@ var allowedSchemes = map[string]bool{
|
||||
// - TrimSpace 后为空视为直连
|
||||
// - url.Parse 失败返回 error(不含原始 URL,防凭据泄露)
|
||||
// - Host 为空返回 error(用 Redacted() 脱敏)
|
||||
// - Scheme 必须为 http/https/socks5/socks5h
|
||||
// - Scheme 必须为 http/socks5/socks5h(https 不支持,因 CONNECT 明文传输)
|
||||
// - socks5:// 自动升级为 socks5h://(确保 DNS 由代理端解析,防止 DNS 泄漏)
|
||||
func Parse(raw string) (trimmed string, parsed *url.URL, err error) {
|
||||
trimmed = strings.TrimSpace(raw)
|
||||
@ -51,7 +54,10 @@ func Parse(raw string) (trimmed string, parsed *url.URL, err error) {
|
||||
|
||||
scheme := strings.ToLower(parsed.Scheme)
|
||||
if !allowedSchemes[scheme] {
|
||||
return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, https, socks5, socks5h)", scheme)
|
||||
if scheme == "https" {
|
||||
return "", nil, fmt.Errorf("https proxy scheme is not supported: current implementation sends CONNECT in plaintext (use http:// or socks5:// instead)")
|
||||
}
|
||||
return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, socks5, socks5h)", scheme)
|
||||
}
|
||||
|
||||
// 自动升级 socks5 → socks5h,确保 DNS 由代理端解析,防止 DNS 泄漏。
|
||||
|
||||
@ -47,13 +47,13 @@ func TestParse_有效HTTP代理(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_有效HTTPS代理(t *testing.T) {
|
||||
_, parsed, err := Parse("https://proxy.example.com:443")
|
||||
if err != nil {
|
||||
t.Fatalf("有效 HTTPS 代理应成功: %v", err)
|
||||
func TestParse_HTTPS代理被拒绝(t *testing.T) {
|
||||
_, _, err := Parse("https://proxy.example.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("https 代理应返回错误(当前实现不支持 TLS-to-proxy)")
|
||||
}
|
||||
if parsed.Scheme != "https" {
|
||||
t.Errorf("Scheme 不匹配: got %q", parsed.Scheme)
|
||||
if !strings.Contains(err.Error(), "https proxy scheme is not supported") {
|
||||
t.Errorf("错误信息应包含 'https proxy scheme is not supported': got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
@ -205,7 +206,7 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco
|
||||
}
|
||||
proxyInfo := "direct"
|
||||
if proxyURL != "" {
|
||||
proxyInfo = proxyURL
|
||||
proxyInfo = logredact.RedactProxyURL(proxyURL)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo, "profile", profile.Name)
|
||||
|
||||
@ -302,7 +303,7 @@ func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID i
|
||||
}
|
||||
|
||||
// 创建带 TLS 指纹的 Transport
|
||||
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", proxyKey)
|
||||
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", logredact.RedactProxyURL(proxyKey))
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
|
||||
if err != nil {
|
||||
|
||||
@ -53,8 +53,9 @@ const migrationsLockRetryInterval = 500 * time.Millisecond
|
||||
const nonTransactionalMigrationSuffix = "_notx.sql"
|
||||
|
||||
type migrationChecksumCompatibilityRule struct {
|
||||
fileChecksum string
|
||||
acceptedDBChecksum map[string]struct{}
|
||||
fileChecksum string
|
||||
acceptedFileChecksums map[string]struct{}
|
||||
acceptedDBChecksum map[string]struct{}
|
||||
}
|
||||
|
||||
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
|
||||
@ -73,6 +74,15 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil
|
||||
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
|
||||
},
|
||||
},
|
||||
"082_create_gateway_debug_logs.sql": {
|
||||
fileChecksum: "b740d7274afbd37d4448e3a3a9aa1fb562181ded5d0319e47a6444187d22f6b1",
|
||||
acceptedFileChecksums: map[string]struct{}{
|
||||
"bf5348a22cf1f27c852096beb3583b67ec43819af82b2f9664397a5638e5b386": {},
|
||||
},
|
||||
acceptedDBChecksum: map[string]struct{}{
|
||||
"d00c2e69711cc0c006b0234566101d8639ba08db77283558f07e2ba412ec177d": {},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
|
||||
@ -328,7 +338,9 @@ func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
|
||||
return false
|
||||
}
|
||||
if rule.fileChecksum != fileChecksum {
|
||||
return false
|
||||
if _, ok := rule.acceptedFileChecksums[fileChecksum]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
_, ok = rule.acceptedDBChecksum[dbChecksum]
|
||||
return ok
|
||||
|
||||
@ -92,6 +92,11 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
|
||||
}
|
||||
require.NotEmpty(t, accepted)
|
||||
require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
|
||||
|
||||
for alternateFileChecksum := range rule.acceptedFileChecksums {
|
||||
require.True(t, isMigrationChecksumCompatible(name, accepted, alternateFileChecksum))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAtlasBaselineAligned(t *testing.T) {
|
||||
|
||||
@ -103,8 +103,15 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
defer cancel()
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(refreshCtx, account, p.executor, antigravityTokenRefreshSkew)
|
||||
if err != nil {
|
||||
// 标记账号临时不可调度,避免后续请求继续命中
|
||||
p.markTempUnschedulable(account, err)
|
||||
// 全局 OAuth 配置缺失不应污染账号状态;账号级失败才标记 temp unschedulable。
|
||||
if shouldMarkTempUnschedulableForRefreshError(err) {
|
||||
p.markTempUnschedulable(account, err)
|
||||
} else {
|
||||
slog.Warn("antigravity_token_provider.temp_unschedulable_skipped",
|
||||
"account_id", account.ID,
|
||||
"reason", err.Error(),
|
||||
)
|
||||
}
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
@ -226,6 +233,23 @@ func (p *AntigravityTokenProvider) markTempUnschedulable(account *Account, refre
|
||||
}
|
||||
}
|
||||
|
||||
func shouldMarkTempUnschedulableForRefreshError(refreshErr error) bool {
|
||||
if refreshErr == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(strings.TrimSpace(refreshErr.Error()))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(msg, "antigravity_oauth_client_secret_missing") {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(msg, "missing antigravity oauth client_secret") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
|
||||
p.backfillCooldown.Store(accountID, time.Now())
|
||||
}
|
||||
|
||||
@ -0,0 +1,20 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestShouldMarkTempUnschedulableForRefreshError(t *testing.T) {
|
||||
t.Run("skip global oauth client secret missing", func(t *testing.T) {
|
||||
err := errors.New(`token 刷新失败 (重试后): error: code=400 reason="ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING" message="missing antigravity oauth client_secret; set ANTIGRAVITY_OAUTH_CLIENT_SECRET" metadata=map[]`)
|
||||
require.False(t, shouldMarkTempUnschedulableForRefreshError(err))
|
||||
})
|
||||
|
||||
t.Run("allow account specific refresh error", func(t *testing.T) {
|
||||
err := errors.New("token 刷新失败 (重试后): invalid_grant")
|
||||
require.True(t, shouldMarkTempUnschedulableForRefreshError(err))
|
||||
})
|
||||
}
|
||||
@ -16,6 +16,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -463,7 +464,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ProxyURL: %s", proxyURL)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ProxyURL: %s", logredact.RedactProxyURL(proxyURL))
|
||||
|
||||
redirectURI := session.RedirectURI
|
||||
|
||||
|
||||
225
backend/internal/service/lspool_bootstrap_service.go
Normal file
225
backend/internal/service/lspool_bootstrap_service.go
Normal file
@ -0,0 +1,225 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultLSPoolBootstrapConcurrency = 4
|
||||
)
|
||||
|
||||
type lsBootstrapAccountReader interface {
|
||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
}
|
||||
|
||||
// LSPoolBootstrapService pre-creates LS workers for eligible Antigravity accounts on startup.
|
||||
type LSPoolBootstrapService struct {
|
||||
accountReader lsBootstrapAccountReader
|
||||
backend lspool.Backend
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
once sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewLSPoolBootstrapService(accountReader lsBootstrapAccountReader, backend lspool.Backend, cfg *config.Config) *LSPoolBootstrapService {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &LSPoolBootstrapService{
|
||||
accountReader: accountReader,
|
||||
backend: backend,
|
||||
cfg: cfg,
|
||||
logger: slog.Default().With("component", "service.lspool_bootstrap"),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// ProvideLSPoolBootstrapService creates and starts the LS pool bootstrap worker.
|
||||
func ProvideLSPoolBootstrapService(accountRepo AccountRepository, cfg *config.Config) *LSPoolBootstrapService {
|
||||
svc := NewLSPoolBootstrapService(accountRepo, lspool.GlobalPool(cfg), cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.once.Do(func() {
|
||||
if s.backend == nil {
|
||||
if lspool.IsLSModeEnabled() {
|
||||
s.logger.Warn("startup bootstrap skipped: ls backend unavailable")
|
||||
}
|
||||
return
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.bootstrap(s.ctx)
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) bootstrap(ctx context.Context) {
|
||||
if s.backend == nil || s.accountReader == nil {
|
||||
return
|
||||
}
|
||||
|
||||
accounts, err := s.accountReader.ListByPlatform(ctx, PlatformAntigravity)
|
||||
if err != nil {
|
||||
s.logger.Warn("load antigravity accounts for ls bootstrap failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
candidates := make([]Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
if shouldBootstrapLSPoolAccount(&accounts[i], now) {
|
||||
candidates = append(candidates, accounts[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
s.logger.Info("startup bootstrap skipped: no eligible antigravity accounts")
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("starting ls worker bootstrap",
|
||||
"accounts_total", len(accounts),
|
||||
"accounts_eligible", len(candidates),
|
||||
"concurrency", s.bootstrapConcurrency())
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
started int
|
||||
failed int
|
||||
)
|
||||
sem := make(chan struct{}, s.bootstrapConcurrency())
|
||||
var wg sync.WaitGroup
|
||||
|
||||
loop:
|
||||
for i := range candidates {
|
||||
account := candidates[i]
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break loop
|
||||
case sem <- struct{}{}:
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(account Account) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }()
|
||||
|
||||
if err := s.bootstrapAccount(&account); err != nil {
|
||||
mu.Lock()
|
||||
failed++
|
||||
mu.Unlock()
|
||||
s.logger.Warn("bootstrap ls worker failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
started++
|
||||
mu.Unlock()
|
||||
s.logger.Info("bootstrap ls worker ready", "account_id", account.ID)
|
||||
}(account)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
s.logger.Info("ls worker bootstrap completed",
|
||||
"accounts_total", len(accounts),
|
||||
"accounts_eligible", len(candidates),
|
||||
"workers_ready", started,
|
||||
"workers_failed", failed,
|
||||
"canceled", ctx.Err() != nil)
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) bootstrapAccount(account *Account) error {
|
||||
if s.backend == nil {
|
||||
return fmt.Errorf("ls backend unavailable")
|
||||
}
|
||||
if account == nil {
|
||||
return fmt.Errorf("account is nil")
|
||||
}
|
||||
|
||||
accountKey := strconv.FormatInt(account.ID, 10)
|
||||
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("missing access token")
|
||||
}
|
||||
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
|
||||
|
||||
expiresAt := time.Time{}
|
||||
if ts := account.GetCredentialAsTime("expires_at"); ts != nil {
|
||||
expiresAt = ts.UTC()
|
||||
}
|
||||
|
||||
s.backend.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt)
|
||||
availableCredits, minimumCreditAmount := resolveLSPoolModelCreditsState(account)
|
||||
s.backend.SetAccountModelCredits(accountKey, account.IsOveragesEnabled(), availableCredits, minimumCreditAmount)
|
||||
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
if _, err := s.backend.GetOrCreate(accountKey, "", proxyURL); err != nil {
|
||||
return fmt.Errorf("get or create ls worker: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) bootstrapConcurrency() int {
|
||||
parallelism := defaultLSPoolBootstrapConcurrency
|
||||
if s.cfg != nil && s.cfg.Gateway.AntigravityLSWorker.MaxActive > 0 && s.cfg.Gateway.AntigravityLSWorker.MaxActive < parallelism {
|
||||
parallelism = s.cfg.Gateway.AntigravityLSWorker.MaxActive
|
||||
}
|
||||
if parallelism < 1 {
|
||||
return 1
|
||||
}
|
||||
return parallelism
|
||||
}
|
||||
|
||||
func shouldBootstrapLSPoolAccount(account *Account, now time.Time) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.Platform != PlatformAntigravity {
|
||||
return false
|
||||
}
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return false
|
||||
}
|
||||
if account.Status != StatusActive || !account.Schedulable {
|
||||
return false
|
||||
}
|
||||
if account.AutoPauseOnExpired && account.ExpiresAt != nil && !now.Before(*account.ExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(account.GetCredential("access_token")) == "" {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(account.GetCredential("project_id")) != ""
|
||||
}
|
||||
262
backend/internal/service/lspool_bootstrap_service_test.go
Normal file
262
backend/internal/service/lspool_bootstrap_service_test.go
Normal file
@ -0,0 +1,262 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeLSBootstrapAccountReader struct {
|
||||
mu sync.Mutex
|
||||
accounts []Account
|
||||
err error
|
||||
platforms []string
|
||||
}
|
||||
|
||||
func (f *fakeLSBootstrapAccountReader) ListByPlatform(_ context.Context, platform string) ([]Account, error) {
|
||||
f.mu.Lock()
|
||||
f.platforms = append(f.platforms, platform)
|
||||
accounts := append([]Account(nil), f.accounts...)
|
||||
err := f.err
|
||||
f.mu.Unlock()
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
type fakeLSPoolBackend struct {
|
||||
mu sync.Mutex
|
||||
tokenCalls map[string]fakeLSPoolTokenCall
|
||||
creditCalls map[string]fakeLSPoolCreditCall
|
||||
getCalls []fakeLSPoolGetCall
|
||||
getErrs map[string]error
|
||||
}
|
||||
|
||||
type fakeLSPoolTokenCall struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type fakeLSPoolCreditCall struct {
|
||||
UseAICredits bool
|
||||
AvailableCredits *int32
|
||||
MinimumCreditAmount *int32
|
||||
}
|
||||
|
||||
type fakeLSPoolGetCall struct {
|
||||
AccountID string
|
||||
RoutingKey string
|
||||
ProxyURL string
|
||||
}
|
||||
|
||||
func newFakeLSPoolBackend() *fakeLSPoolBackend {
|
||||
return &fakeLSPoolBackend{
|
||||
tokenCalls: make(map[string]fakeLSPoolTokenCall),
|
||||
creditCalls: make(map[string]fakeLSPoolCreditCall),
|
||||
getErrs: make(map[string]error),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*lspool.Instance, error) {
|
||||
rawProxy := ""
|
||||
if len(proxyURL) > 0 {
|
||||
rawProxy = proxyURL[0]
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.getCalls = append(f.getCalls, fakeLSPoolGetCall{
|
||||
AccountID: accountID,
|
||||
RoutingKey: routingKey,
|
||||
ProxyURL: rawProxy,
|
||||
})
|
||||
if err := f.getErrs[accountID]; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &lspool.Instance{AccountID: accountID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.tokenCalls[accountID] = fakeLSPoolTokenCall{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.creditCalls[accountID] = fakeLSPoolCreditCall{
|
||||
UseAICredits: useAICredits,
|
||||
AvailableCredits: copyInt32Ptr(availableCredits),
|
||||
MinimumCreditAmount: copyInt32Ptr(minimumCreditAmountForUsage),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) Stats() map[string]any { return nil }
|
||||
|
||||
func (f *fakeLSPoolBackend) Close() {}
|
||||
|
||||
func copyInt32Ptr(v *int32) *int32 {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *v
|
||||
return &cp
|
||||
}
|
||||
|
||||
func TestLSPoolBootstrapServiceBootstrapEligibleAccounts(t *testing.T) {
|
||||
expiresAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
|
||||
expiredAt := time.Now().Add(-2 * time.Hour)
|
||||
reader := &fakeLSBootstrapAccountReader{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 101,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-101",
|
||||
"refresh_token": "refresh-101",
|
||||
"expires_at": expiresAt.Format(time.RFC3339),
|
||||
"project_id": "proj-101",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
"ai_credits": []any{
|
||||
map[string]any{
|
||||
"credit_type": "GOOGLE_ONE_AI",
|
||||
"amount": 120,
|
||||
"minimum_balance": 55,
|
||||
},
|
||||
},
|
||||
},
|
||||
Proxy: &Proxy{
|
||||
Protocol: "socks5h",
|
||||
Host: "127.0.0.1",
|
||||
Port: 1080,
|
||||
Username: "alice",
|
||||
Password: "secret",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 102,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: false,
|
||||
Credentials: map[string]any{"access_token": "token-102", "project_id": "proj-102"},
|
||||
},
|
||||
{
|
||||
ID: 103,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-103"},
|
||||
},
|
||||
{
|
||||
ID: 104,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
AutoPauseOnExpired: true,
|
||||
ExpiresAt: &expiredAt,
|
||||
Credentials: map[string]any{"access_token": "token-104", "project_id": "proj-104"},
|
||||
},
|
||||
{
|
||||
ID: 106,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeUpstream,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-106", "project_id": "proj-106"},
|
||||
},
|
||||
{
|
||||
ID: 105,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-105"},
|
||||
},
|
||||
},
|
||||
}
|
||||
backend := newFakeLSPoolBackend()
|
||||
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
AntigravityLSWorker: config.GatewayAntigravityLSWorkerConfig{MaxActive: 3},
|
||||
},
|
||||
})
|
||||
|
||||
svc.bootstrap(context.Background())
|
||||
|
||||
require.Equal(t, []string{PlatformAntigravity}, reader.platforms)
|
||||
|
||||
require.Len(t, backend.getCalls, 1)
|
||||
require.Equal(t, fakeLSPoolGetCall{
|
||||
AccountID: "101",
|
||||
RoutingKey: "",
|
||||
ProxyURL: "socks5h://alice:secret@127.0.0.1:1080",
|
||||
}, backend.getCalls[0])
|
||||
|
||||
tokenCall, ok := backend.tokenCalls["101"]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "token-101", tokenCall.AccessToken)
|
||||
require.Equal(t, "refresh-101", tokenCall.RefreshToken)
|
||||
require.Equal(t, expiresAt, tokenCall.ExpiresAt)
|
||||
|
||||
creditCall, ok := backend.creditCalls["101"]
|
||||
require.True(t, ok)
|
||||
require.True(t, creditCall.UseAICredits)
|
||||
require.NotNil(t, creditCall.AvailableCredits)
|
||||
require.Equal(t, int32(120), *creditCall.AvailableCredits)
|
||||
require.NotNil(t, creditCall.MinimumCreditAmount)
|
||||
require.Equal(t, int32(55), *creditCall.MinimumCreditAmount)
|
||||
|
||||
require.NotContains(t, backend.tokenCalls, "102")
|
||||
require.NotContains(t, backend.tokenCalls, "103")
|
||||
require.NotContains(t, backend.tokenCalls, "104")
|
||||
require.NotContains(t, backend.tokenCalls, "106")
|
||||
}
|
||||
|
||||
func TestLSPoolBootstrapServiceBootstrapContinuesOnWorkerFailure(t *testing.T) {
|
||||
reader := &fakeLSBootstrapAccountReader{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 201,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-201", "project_id": "proj-201"},
|
||||
},
|
||||
{
|
||||
ID: 202,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-202", "project_id": "proj-202"},
|
||||
},
|
||||
},
|
||||
}
|
||||
backend := newFakeLSPoolBackend()
|
||||
backend.getErrs["201"] = errors.New("create failed")
|
||||
|
||||
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{})
|
||||
svc.bootstrap(context.Background())
|
||||
|
||||
require.Len(t, backend.getCalls, 2)
|
||||
require.Contains(t, backend.tokenCalls, "201")
|
||||
require.Contains(t, backend.tokenCalls, "202")
|
||||
}
|
||||
@ -471,6 +471,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewCRSSyncService,
|
||||
ProvideUpdateService,
|
||||
ProvideTokenRefreshService,
|
||||
ProvideLSPoolBootstrapService,
|
||||
ProvideAccountExpiryService,
|
||||
ProvideSubscriptionExpiryService,
|
||||
ProvideTimingWheelService,
|
||||
|
||||
@ -2,6 +2,7 @@ package logredact
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
@ -230,3 +231,19 @@ func isSensitiveKey(key string, keys map[string]struct{}) bool {
|
||||
func normalizeKey(key string) string {
|
||||
return strings.ToLower(strings.TrimSpace(key))
|
||||
}
|
||||
|
||||
// RedactProxyURL strips userinfo (username:password) from a proxy URL string
|
||||
// for safe logging. Returns the input unchanged if it's not a valid URL.
|
||||
func RedactProxyURL(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return "<redacted-proxy-url>"
|
||||
}
|
||||
if parsed.User != nil {
|
||||
parsed.User = nil
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
@ -38,6 +38,34 @@ func TestRedactText_GOCSPX(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactProxyURL_StripsUserinfo(t *testing.T) {
|
||||
in := "http://user:pass@proxy.example.com:8080"
|
||||
out := RedactProxyURL(in)
|
||||
if out != "http://proxy.example.com:8080" {
|
||||
t.Fatalf("expected userinfo stripped, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactProxyURL_EmptyString(t *testing.T) {
|
||||
if got := RedactProxyURL(""); got != "" {
|
||||
t.Fatalf("expected empty string, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactProxyURL_NoUserinfo(t *testing.T) {
|
||||
in := "socks5h://proxy.example.com:1080"
|
||||
out := RedactProxyURL(in)
|
||||
if out != in {
|
||||
t.Fatalf("expected unchanged URL, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactProxyURL_InvalidURL(t *testing.T) {
|
||||
if got := RedactProxyURL("://invalid"); got != "<redacted-proxy-url>" {
|
||||
t.Fatalf("unexpected invalid URL redaction result: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactText_ExtraKeyCacheUsesNormalizedSortedKey(t *testing.T) {
|
||||
clearExtraTextPatternCache()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user