win b856586412
Some checks failed
CI / test (push) Failing after 16m30s
CI / golangci-lint (push) Failing after 4s
Security Scan / backend-security (push) Failing after 1m35s
Security Scan / frontend-security (push) Failing after 1m31s
修复h1
2026-04-01 01:35:49 +08:00

377 lines
13 KiB
Go

package lspool
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
)
// TestBuildLSEnvKeepsExistingSSLValues checks that buildLSEnv adds the editor
// app root variable while keeping caller-provided SSL_CERT_FILE and
// SSL_CERT_DIR entries exactly as given.
func TestBuildLSEnvKeepsExistingSSLValues(t *testing.T) {
	env := buildLSEnv([]string{
		"SSL_CERT_FILE=/custom/ca.pem",
		"SSL_CERT_DIR=/custom/certs",
	}, "/opt/antigravity", "")

	for _, want := range []string{
		"ANTIGRAVITY_EDITOR_APP_ROOT=/opt/antigravity",
		"SSL_CERT_FILE=/custom/ca.pem",
		"SSL_CERT_DIR=/custom/certs",
	} {
		require.Contains(t, env, want)
	}
}
// TestBuildLSEnvClearsInheritedProxyWhenUnset checks the case where buildLSEnv
// gets an empty final argument (presumably "no proxy configured"). Every
// proxy variable inherited from the parent environment, in both upper and
// lower case, must come back as an explicitly empty assignment.
func TestBuildLSEnvClearsInheritedProxyWhenUnset(t *testing.T) {
	inherited := []string{
		"HTTPS_PROXY=http://old-proxy:8080",
		"HTTP_PROXY=http://old-proxy:8080",
		"ALL_PROXY=socks5://old-proxy:1080",
		"https_proxy=http://old-proxy:8080",
		"http_proxy=http://old-proxy:8080",
		"all_proxy=socks5://old-proxy:1080",
	}
	env := buildLSEnv(inherited, "/opt/antigravity", "")

	clearedKeys := []string{"HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY", "https_proxy", "http_proxy", "all_proxy"}
	for _, key := range clearedKeys {
		require.Contains(t, env, key+"=")
	}
}
// TestShortAccountID checks that shortAccountID returns IDs of up to eight
// characters unchanged and truncates longer IDs to their first eight.
func TestShortAccountID(t *testing.T) {
	for _, tc := range []struct{ in, want string }{
		{"9", "9"},
		{"12345678", "12345678"},
		{"123456789", "12345678"},
	} {
		require.Equal(t, tc.want, shortAccountID(tc.in))
	}
}
// TestFrameConnectMessage checks the frame layout produced by
// frameConnectMessage: a zero flag byte, a big-endian uint32 payload length,
// and then the payload bytes unchanged.
func TestFrameConnectMessage(t *testing.T) {
	const payload = `{"x":1}`
	framed := frameConnectMessage([]byte(payload))

	require.Len(t, framed, 5+len(payload))
	header, body := framed[:5], framed[5:]
	require.Equal(t, byte(0), header[0])
	require.Equal(t, uint32(len(payload)), binary.BigEndian.Uint32(header[1:]))
	require.Equal(t, payload, string(body))
}
// TestConnectEnvelope checks that connectEnvelope writes the given flag byte,
// then the payload length as a big-endian uint32, then the payload itself.
func TestConnectEnvelope(t *testing.T) {
	body := []byte("hello")
	envelope := connectEnvelope(0x00, body)

	require.Len(t, envelope, 5+len(body))
	flag, length, rest := envelope[0], binary.BigEndian.Uint32(envelope[1:5]), envelope[5:]
	require.Equal(t, byte(0x00), flag)
	require.Equal(t, uint32(5), length)
	require.Equal(t, "hello", string(rest))
}
func TestUnwrapConnectEnvelope(t *testing.T) {
payload := []byte("test data")
env := connectEnvelope(0x00, payload)
unwrapped := unwrapConnectEnvelope(env)
require.Equal(t, payload, unwrapped)
short := []byte{1, 2}
require.Equal(t, short, unwrapConnectEnvelope(short))
}
// TestExtractPromptAndModel covers two request shapes:
//   - the wrapped form (model/project at the top, contents under "request"),
//     from which both prompt and model are extracted;
//   - the bare form with top-level "contents", from which the prompt is
//     extracted.
func TestExtractPromptAndModel(t *testing.T) {
	wrapped := []byte(`{"model":"gemini-2.5-pro","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hello world"}]}]}}`)
	gotPrompt, gotModel := extractPromptAndModel(wrapped)
	require.Equal(t, "hello world", gotPrompt)
	require.Equal(t, "gemini-2.5-pro", gotModel)

	bare := []byte(`{"contents":[{"role":"user","parts":[{"text":"test prompt"}]}]}`)
	barePrompt, _ := extractPromptAndModel(bare)
	require.Equal(t, "test prompt", barePrompt)
}
// TestResolveModelEnum checks that resolveModelEnum returns a positive enum
// value for a bare Gemini name, a "models/"-prefixed name, a Claude name and
// an unknown name. This test loads no dynamic model mapping first, so the
// values presumably come from the built-in fallback.
func TestResolveModelEnum(t *testing.T) {
	for _, name := range []string{
		"gemini-2.5-flash",
		"models/gemini-2.5-flash",
		"claude-sonnet-4-6",
		"unknown-model",
	} {
		// require.Positive prints the offending value on failure, unlike
		// require.True(x > 0). It works for any numeric return type and
		// satisfies testifylint's compares check.
		require.Positive(t, resolveModelEnum(name), "model %q", name)
	}
}
// TestBuildCascadeConfigIncludesRequestedModel checks buildCascadeConfig for
// a "models/"-prefixed Gemini name. The result must contain a plannerConfig
// map whose only entry is requestedModel, and requestedModel must carry a
// non-empty "model" value.
func TestBuildCascadeConfigIncludesRequestedModel(t *testing.T) {
	cfg := buildCascadeConfig("models/gemini-2.5-flash")
	require.NotNil(t, cfg)

	planner, isMap := cfg["plannerConfig"].(map[string]any)
	require.True(t, isMap)

	requested, isMap := planner["requestedModel"].(map[string]any)
	require.True(t, isMap)
	require.NotEmpty(t, requested["model"])

	// requestedModel must be the only planner setting emitted.
	require.Len(t, planner, 1)
}
// TestBuildCascadeConfigClaudeIncludesRequestedModel checks buildCascadeConfig
// for a Claude model name. As with Gemini names, plannerConfig must hold only
// requestedModel, and requestedModel must carry a non-empty "model" value.
func TestBuildCascadeConfigClaudeIncludesRequestedModel(t *testing.T) {
	cfg := buildCascadeConfig("claude-sonnet-4-6")
	require.NotNil(t, cfg)

	planner, isMap := cfg["plannerConfig"].(map[string]any)
	require.True(t, isMap)

	requested, isMap := planner["requestedModel"].(map[string]any)
	require.True(t, isMap)
	require.NotEmpty(t, requested["model"])

	// requestedModel must be the only planner setting emitted.
	require.Len(t, planner, 1)
}
// TestDoNonStreamGeneratePassesThrough sends a non-streaming generateContent
// request through an LS-pool upstream backed by a zero-value Pool. It checks
// that the request is forwarded to the fallback upstream exactly once.
func TestDoNonStreamGeneratePassesThrough(t *testing.T) {
	fallback := &recordingUpstream{}
	upstream := NewLSPoolUpstream(&Pool{}, fallback)

	req, err := http.NewRequest(http.MethodPost, "https://example.com/v1beta/models/gemini:generateContent", bytes.NewReader([]byte(`{}`)))
	require.NoError(t, err)

	resp, err := upstream.Do(req, "", 1, 1)
	require.NoError(t, err)
	require.NotNil(t, resp)
	// Always release the response body, as net/http requires and bodyclose
	// checks. The nil guard avoids a panic if a non-fallback response
	// without a body is ever returned.
	if resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}

	require.Equal(t, 1, fallback.doCalls)
}
// TestExtractPlannerResponseText checks a finished (IDLE) trajectory. The
// planner-response step's text is returned, generation is reported as
// finished, and the top-level run status is passed through.
func TestExtractPlannerResponseText(t *testing.T) {
	raw := []byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"},
{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE",
"plannerResponse":{"response":"Hello world"}}
]}}`)

	gotText, stillGenerating, runStatus := extractPlannerResponseText(raw)
	require.Equal(t, "Hello world", gotText)
	require.False(t, stillGenerating)
	require.Equal(t, "CASCADE_RUN_STATUS_IDLE", runStatus)
}
// TestExtractPlannerResponseState_ErrorDetails checks an IDLE trajectory that
// has no planner response but does have executor error details. The state
// must report no text and no generation, and its ErrorMessage must include
// both the short error and the detailed text.
func TestExtractPlannerResponseState_ErrorDetails(t *testing.T) {
	raw := []byte(`{
"status":"CASCADE_RUN_STATUS_IDLE",
"trajectory":{
"steps":[
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"}
],
"executorMetadata":{
"terminationReason":"ERROR",
"errorDetails":{
"errorCode":429,
"shortError":"Model quota reached",
"details":"You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
}
}
}
}`)

	got := extractPlannerResponseState(raw)
	require.Equal(t, "CASCADE_RUN_STATUS_IDLE", got.Status)
	require.False(t, got.Generating)
	require.Empty(t, got.Text)

	// Both the short error and the long detail should reach the caller.
	for _, fragment := range []string{"Model quota reached", "quota will reset after"} {
		require.Contains(t, got.ErrorMessage, fragment)
	}
}
// TestBuildGeminiSSEChunk checks that buildGeminiSSEChunk emits an SSE
// "data: " line containing a model-role text part, and that the chunk ends
// with a blank line ("\n\n").
func TestBuildGeminiSSEChunk(t *testing.T) {
	chunk := buildGeminiSSEChunk("hello")

	for _, fragment := range []string{"data: ", `"text":"hello"`, `"role":"model"`} {
		require.Contains(t, chunk, fragment)
	}
	require.True(t, strings.HasSuffix(chunk, "\n\n"))
}
// TestRequestHasTools checks tool detection for both the wrapped and the
// direct request shapes. Requests without a "tools" key, or with an empty
// tools array, must report no tools.
func TestRequestHasTools(t *testing.T) {
	cases := []struct {
		desc string
		body string
		want bool
	}{
		{"wrapped format with tools", `{"model":"m","project":"p","request":{"contents":[],"tools":[{"functionDeclarations":[{"name":"get_weather"}]}]}}`, true},
		{"direct format with tools", `{"contents":[],"tools":[{"functionDeclarations":[{"name":"f"}]}]}`, true},
		{"no tools", `{"model":"m","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`, false},
		{"empty tools array", `{"contents":[],"tools":[]}`, false},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, requestHasTools([]byte(tc.body)), tc.desc)
	}
}
// TestCurrentLSStrategy checks that ANTIGRAVITY_LS_STRATEGY=js-parity selects
// the JS-parity strategy and that an unrecognised value selects the direct
// strategy.
func TestCurrentLSStrategy(t *testing.T) {
	expectStrategy := func(raw string, want any) {
		t.Helper()
		t.Setenv("ANTIGRAVITY_LS_STRATEGY", raw)
		require.Equal(t, want, CurrentLSStrategy())
	}

	expectStrategy("js-parity", LSStrategyJSParity)
	expectStrategy("unknown", LSStrategyDirect)
}
// TestIsPermanentModelMappingError checks that an OAuth "unauthorized_client"
// failure counts as permanent, while a context deadline does not.
func TestIsPermanentModelMappingError(t *testing.T) {
	unauthorized := errors.New(`oauth2: "unauthorized_client" "Unauthorized"`)
	deadline := errors.New("context deadline exceeded")

	require.True(t, isPermanentModelMappingError(unauthorized))
	require.False(t, isPermanentModelMappingError(deadline))
}
func TestPoolSetAccountTokenClearsModelMappingUnavailable(t *testing.T) {
pool := &Pool{
instances: map[string][]*Instance{
"9": {
{AccountID: "9", Replica: 0},
},
},
}
inst := pool.instances["9"][0]
inst.SetModelMappingReady(true)
inst.SetModelMappingUnavailable(`oauth2: "unauthorized_client" "Unauthorized"`)
pool.SetAccountToken("9", "ya29.new", "refresh", time.Now().Add(time.Hour))
require.False(t, inst.HasModelMappingReady())
require.False(t, inst.HasModelMappingUnavailable())
require.Empty(t, inst.ModelMappingUnavailableReason())
}
func TestShouldFallbackDirectForModelMappingUnavailable(t *testing.T) {
require.True(t, shouldFallbackDirect(fmt.Errorf("%w: oauth2 unauthorized_client", errLSModelMapDenied)))
require.False(t, shouldFallbackDirect(errLSModelMapPending))
}
// TestParseLSReplicaCountDefaultAndEnv checks the handling of
// ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT. A positive value is used as-is, while
// an empty value or "0" falls back to 5.
func TestParseLSReplicaCountDefaultAndEnv(t *testing.T) {
	for _, tc := range []struct {
		raw  string
		want int
	}{
		{"", 5},
		{"3", 3},
		{"0", 5},
	} {
		t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", tc.raw)
		require.Equal(t, tc.want, parseLSReplicaCount())
	}
}
func TestPoolGetUsesStickyReplicaSlot(t *testing.T) {
pool := &Pool{
config: Config{ReplicasPerAccount: 5},
instances: map[string][]*Instance{
"acc-1": {
{AccountID: "acc-1", Replica: 0, healthy: true},
{AccountID: "acc-1", Replica: 1, healthy: true},
{AccountID: "acc-1", Replica: 2, healthy: true},
{AccountID: "acc-1", Replica: 3, healthy: true},
{AccountID: "acc-1", Replica: 4, healthy: true},
},
},
}
routingKey := "acc-1:user-a:session-1"
slot := replicaSlotIndex(routingKey, pool.replicaCount())
inst := pool.Get("acc-1", routingKey)
require.NotNil(t, inst)
require.Equal(t, slot, inst.Replica)
}
// TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica checks that, without a
// routing key, Pool.Get picks the replica with the lower in-flight counter
// (1 versus 4).
func TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica(t *testing.T) {
	busy := &Instance{AccountID: "acc-1", Replica: 0, healthy: true}
	idle := &Instance{AccountID: "acc-1", Replica: 1, healthy: true}
	atomic.StoreInt64(&busy.inflight, 4)
	atomic.StoreInt64(&idle.inflight, 1)

	pool := &Pool{
		config:    Config{ReplicasPerAccount: 5},
		instances: map[string][]*Instance{"acc-1": {busy, idle}},
	}

	picked := pool.Get("acc-1", "")
	require.NotNil(t, picked)
	// Replica 1 is the idle instance.
	require.Equal(t, 1, picked.Replica)
}
// TestWaitForInstanceReadyProbesImmediately checks that waitForInstanceReady
// runs its first probe right away instead of waiting one retry interval first.
//
// The old version asserted that less than 100ms had elapsed, which is flaky
// on loaded CI runners (especially under -race). This version makes the
// retry interval far longer than the context deadline instead:
//   - if the first probe waited for the context, the deadline would expire
//     before the probe ran, so NoError or the attempt count would fail;
//   - if it slept an interval without checking the context, elapsed time
//     would reach the interval.
//
// Either way the check is deterministic.
func TestWaitForInstanceReadyProbesImmediately(t *testing.T) {
	const interval = 10 * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	startedAt := time.Now()
	attempts, err := waitForInstanceReady(ctx, interval, func(context.Context) error {
		return nil
	})
	require.NoError(t, err)
	require.Equal(t, 1, attempts)
	require.Less(t, time.Since(startedAt), interval)
}
// TestWaitForInstanceReadyRetriesUntilSuccess checks that waitForInstanceReady
// keeps re-running a failing probe until it succeeds. The probe fails twice
// and succeeds on the third call, and that call count must be reported as
// the attempt count.
func TestWaitForInstanceReadyRetriesUntilSuccess(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	probes := 0
	probe := func(context.Context) error {
		probes++
		if probes >= 3 {
			return nil
		}
		return errors.New("not ready")
	}

	attempts, err := waitForInstanceReady(ctx, 10*time.Millisecond, probe)
	require.NoError(t, err)
	require.Equal(t, 3, attempts)
	require.Equal(t, 3, probes)
}
// TestDecideJSParityRoute checks three JS-parity routing decisions:
//   - a plain text request with a sessionId goes through the LS;
//   - an image-generating request does not, and the reason mentions "image";
//   - a request without a sessionId does not go through the LS.
func TestDecideJSParityRoute(t *testing.T) {
	cases := []struct {
		name       string
		body       string
		wantLS     bool
		reasonHint string // lower-case substring expected in Reason; "" skips the check
	}{
		{
			name:   "text with session",
			body:   `{"model":"gemini-2.5-flash","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`,
			wantLS: true,
		},
		{
			name:       "image generation",
			body:       `{"model":"gemini-2.5-flash-image","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"draw"}]}],"generationConfig":{"responseModalities":["TEXT","IMAGE"]}}}`,
			wantLS:     false,
			reasonHint: "image",
		},
		{
			name:   "no session",
			body:   `{"model":"gemini-2.5-flash","request":{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`,
			wantLS: false,
		},
	}

	for _, tc := range cases {
		raw := []byte(tc.body)
		parsed, err := parseGeminiRequest(raw)
		require.NoError(t, err, tc.name)

		decision := decideJSParityRoute(parsed, raw)
		require.Equal(t, tc.wantLS, decision.UseLS, tc.name)
		if tc.reasonHint != "" {
			require.Contains(t, strings.ToLower(decision.Reason), tc.reasonHint, tc.name)
		}
	}
}
// TestUserNamespacePrefersExplicitHeader checks that userNamespace yields a
// non-"anon" namespace both with the explicit namespace header and with only
// an Authorization header. It also checks that the two namespaces differ, so
// the explicit header wins when both headers are present.
func TestUserNamespacePrefersExplicitHeader(t *testing.T) {
	req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
	require.NoError(t, err)
	req.Header.Set(userNamespaceHeader, "tenant-a")
	req.Header.Set("Authorization", "Bearer oauth-token")

	explicitNS := userNamespace(req)
	require.NotEqual(t, "anon", explicitNS)

	// Drop the explicit header so only the Authorization header remains.
	req.Header.Del(userNamespaceHeader)
	authNS := userNamespace(req)
	require.NotEqual(t, "anon", authNS)

	require.NotEqual(t, explicitNS, authNS)
}
// TestConversationPrefixEqual checks that conversationPrefixEqual(full, prefix)
// is true when full extends prefix by one more turn, and false when the
// arguments are swapped.
func TestConversationPrefixEqual(t *testing.T) {
	history := []geminiConversationTurn{
		{Role: "user", Parts: []geminiConversationPart{{Kind: "text", Text: "hello"}}},
		{Role: "model", Parts: []geminiConversationPart{{Kind: "text", Text: "world"}}},
	}
	followUp := geminiConversationTurn{
		Role:  "user",
		Parts: []geminiConversationPart{{Kind: "text", Text: "follow up"}},
	}
	// Clone first so appending cannot write into history's backing array.
	extended := append(cloneConversationTurns(history), followUp)

	require.True(t, conversationPrefixEqual(extended, history))
	require.False(t, conversationPrefixEqual(history, extended))
}
// TestSystemTextCompatible checks systemTextCompatible over four cases:
//   - an empty second argument is accepted;
//   - identical texts are accepted;
//   - an empty first argument with a non-empty second is rejected;
//   - two different non-empty texts are rejected.
func TestSystemTextCompatible(t *testing.T) {
	for _, tc := range []struct {
		first, second string
		want          bool
	}{
		{"You are helpful", "", true},
		{"You are helpful", "You are helpful", true},
		{"", "You are helpful", false},
		{"You are helpful", "You are different", false},
	} {
		require.Equal(t, tc.want, systemTextCompatible(tc.first, tc.second), "(%q, %q)", tc.first, tc.second)
	}
}
// recordingUpstream is the fallback upstream test double passed to
// NewLSPoolUpstream. It counts calls to Do and always answers with a canned
// 200 "ok" response.
type recordingUpstream struct {
	doCalls int // number of Do invocations, including those made via DoWithTLS
}

// Do increments doCalls and returns a 200 response whose body is "ok". The
// response has an empty, non-nil header and req as its Request.
func (u *recordingUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
	u.doCalls++
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{},
		Body:       io.NopCloser(bytes.NewBufferString("ok")),
		Request:    req,
	}
	return resp, nil
}

// DoWithTLS ignores the TLS fingerprint profile and delegates to Do.
func (u *recordingUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ *tlsfingerprint.Profile) (*http.Response, error) {
	return u.Do(req, proxyURL, accountID, concurrency)
}