377 lines
13 KiB
Go
377 lines
13 KiB
Go
package lspool
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestBuildLSEnvKeepsExistingSSLValues(t *testing.T) {
|
|
env := buildLSEnv([]string{
|
|
"SSL_CERT_FILE=/custom/ca.pem",
|
|
"SSL_CERT_DIR=/custom/certs",
|
|
}, "/opt/antigravity", "")
|
|
require.Contains(t, env, "ANTIGRAVITY_EDITOR_APP_ROOT=/opt/antigravity")
|
|
require.Contains(t, env, "SSL_CERT_FILE=/custom/ca.pem")
|
|
require.Contains(t, env, "SSL_CERT_DIR=/custom/certs")
|
|
}
|
|
|
|
func TestBuildLSEnvClearsInheritedProxyWhenUnset(t *testing.T) {
|
|
env := buildLSEnv([]string{
|
|
"HTTPS_PROXY=http://old-proxy:8080",
|
|
"HTTP_PROXY=http://old-proxy:8080",
|
|
"ALL_PROXY=socks5://old-proxy:1080",
|
|
"https_proxy=http://old-proxy:8080",
|
|
"http_proxy=http://old-proxy:8080",
|
|
"all_proxy=socks5://old-proxy:1080",
|
|
}, "/opt/antigravity", "")
|
|
|
|
require.Contains(t, env, "HTTPS_PROXY=")
|
|
require.Contains(t, env, "HTTP_PROXY=")
|
|
require.Contains(t, env, "ALL_PROXY=")
|
|
require.Contains(t, env, "https_proxy=")
|
|
require.Contains(t, env, "http_proxy=")
|
|
require.Contains(t, env, "all_proxy=")
|
|
}
|
|
|
|
func TestShortAccountID(t *testing.T) {
|
|
require.Equal(t, "9", shortAccountID("9"))
|
|
require.Equal(t, "12345678", shortAccountID("12345678"))
|
|
require.Equal(t, "12345678", shortAccountID("123456789"))
|
|
}
|
|
|
|
func TestFrameConnectMessage(t *testing.T) {
|
|
framed := frameConnectMessage([]byte(`{"x":1}`))
|
|
require.Len(t, framed, 5+len(`{"x":1}`))
|
|
require.Equal(t, byte(0), framed[0])
|
|
require.Equal(t, uint32(len(`{"x":1}`)), binary.BigEndian.Uint32(framed[1:5]))
|
|
require.Equal(t, `{"x":1}`, string(framed[5:]))
|
|
}
|
|
|
|
func TestConnectEnvelope(t *testing.T) {
|
|
payload := []byte("hello")
|
|
env := connectEnvelope(0x00, payload)
|
|
require.Len(t, env, 5+len(payload))
|
|
require.Equal(t, byte(0x00), env[0])
|
|
require.Equal(t, uint32(5), binary.BigEndian.Uint32(env[1:5]))
|
|
require.Equal(t, "hello", string(env[5:]))
|
|
}
|
|
|
|
func TestUnwrapConnectEnvelope(t *testing.T) {
|
|
payload := []byte("test data")
|
|
env := connectEnvelope(0x00, payload)
|
|
unwrapped := unwrapConnectEnvelope(env)
|
|
require.Equal(t, payload, unwrapped)
|
|
short := []byte{1, 2}
|
|
require.Equal(t, short, unwrapConnectEnvelope(short))
|
|
}
|
|
|
|
func TestExtractPromptAndModel(t *testing.T) {
|
|
body := `{"model":"gemini-2.5-pro","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hello world"}]}]}}`
|
|
prompt, model := extractPromptAndModel([]byte(body))
|
|
require.Equal(t, "hello world", prompt)
|
|
require.Equal(t, "gemini-2.5-pro", model)
|
|
|
|
body2 := `{"contents":[{"role":"user","parts":[{"text":"test prompt"}]}]}`
|
|
prompt2, _ := extractPromptAndModel([]byte(body2))
|
|
require.Equal(t, "test prompt", prompt2)
|
|
}
|
|
|
|
func TestResolveModelEnum(t *testing.T) {
|
|
// Without dynamic mapping loaded, should return fallback (312 = gemini-2.5-flash)
|
|
require.True(t, resolveModelEnum("gemini-2.5-flash") > 0)
|
|
require.True(t, resolveModelEnum("models/gemini-2.5-flash") > 0)
|
|
require.True(t, resolveModelEnum("claude-sonnet-4-6") > 0)
|
|
require.True(t, resolveModelEnum("unknown-model") > 0)
|
|
}
|
|
|
|
func TestBuildCascadeConfigIncludesRequestedModel(t *testing.T) {
|
|
cfg := buildCascadeConfig("models/gemini-2.5-flash")
|
|
require.NotNil(t, cfg)
|
|
|
|
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
|
|
require.True(t, ok)
|
|
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
|
require.True(t, ok)
|
|
require.NotEmpty(t, requestedModel["model"])
|
|
require.Len(t, plannerConfig, 1)
|
|
}
|
|
|
|
func TestBuildCascadeConfigClaudeIncludesRequestedModel(t *testing.T) {
|
|
cfg := buildCascadeConfig("claude-sonnet-4-6")
|
|
require.NotNil(t, cfg)
|
|
|
|
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
|
|
require.True(t, ok)
|
|
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
|
require.True(t, ok)
|
|
require.NotEmpty(t, requestedModel["model"])
|
|
require.Len(t, plannerConfig, 1)
|
|
}
|
|
|
|
func TestDoNonStreamGeneratePassesThrough(t *testing.T) {
|
|
fallback := &recordingUpstream{}
|
|
upstream := NewLSPoolUpstream(&Pool{}, fallback)
|
|
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", bytes.NewReader([]byte(`{}`)))
|
|
resp, err := upstream.Do(req, "", 1, 1)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resp)
|
|
require.Equal(t, 1, fallback.doCalls)
|
|
}
|
|
|
|
func TestExtractPlannerResponseText(t *testing.T) {
|
|
resp := `{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[
|
|
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"},
|
|
{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE",
|
|
"plannerResponse":{"response":"Hello world"}}
|
|
]}}`
|
|
text, generating, status := extractPlannerResponseText([]byte(resp))
|
|
require.Equal(t, "Hello world", text)
|
|
require.False(t, generating)
|
|
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", status)
|
|
}
|
|
|
|
func TestExtractPlannerResponseState_ErrorDetails(t *testing.T) {
|
|
resp := `{
|
|
"status":"CASCADE_RUN_STATUS_IDLE",
|
|
"trajectory":{
|
|
"steps":[
|
|
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"}
|
|
],
|
|
"executorMetadata":{
|
|
"terminationReason":"ERROR",
|
|
"errorDetails":{
|
|
"errorCode":429,
|
|
"shortError":"Model quota reached",
|
|
"details":"You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
|
|
}
|
|
}
|
|
}
|
|
}`
|
|
|
|
state := extractPlannerResponseState([]byte(resp))
|
|
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", state.Status)
|
|
require.False(t, state.Generating)
|
|
require.Empty(t, state.Text)
|
|
require.Contains(t, state.ErrorMessage, "Model quota reached")
|
|
require.Contains(t, state.ErrorMessage, "quota will reset after")
|
|
}
|
|
|
|
func TestBuildGeminiSSEChunk(t *testing.T) {
|
|
sse := buildGeminiSSEChunk("hello")
|
|
require.Contains(t, sse, "data: ")
|
|
require.Contains(t, sse, `"text":"hello"`)
|
|
require.Contains(t, sse, `"role":"model"`)
|
|
require.True(t, strings.HasSuffix(sse, "\n\n"))
|
|
}
|
|
|
|
func TestRequestHasTools(t *testing.T) {
|
|
// Wrapped format with tools
|
|
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"functionDeclarations":[{"name":"get_weather"}]}]}}`)))
|
|
|
|
// Direct format with tools
|
|
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[{"name":"f"}]}]}`)))
|
|
|
|
// No tools
|
|
require.False(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`)))
|
|
|
|
// Empty tools array
|
|
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":[]}`)))
|
|
}
|
|
|
|
func TestCurrentLSStrategy(t *testing.T) {
|
|
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "js-parity")
|
|
require.Equal(t, LSStrategyJSParity, CurrentLSStrategy())
|
|
|
|
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "unknown")
|
|
require.Equal(t, LSStrategyDirect, CurrentLSStrategy())
|
|
}
|
|
|
|
func TestIsPermanentModelMappingError(t *testing.T) {
|
|
require.True(t, isPermanentModelMappingError(errors.New(`oauth2: "unauthorized_client" "Unauthorized"`)))
|
|
require.False(t, isPermanentModelMappingError(errors.New("context deadline exceeded")))
|
|
}
|
|
|
|
func TestPoolSetAccountTokenClearsModelMappingUnavailable(t *testing.T) {
|
|
pool := &Pool{
|
|
instances: map[string][]*Instance{
|
|
"9": {
|
|
{AccountID: "9", Replica: 0},
|
|
},
|
|
},
|
|
}
|
|
inst := pool.instances["9"][0]
|
|
inst.SetModelMappingReady(true)
|
|
inst.SetModelMappingUnavailable(`oauth2: "unauthorized_client" "Unauthorized"`)
|
|
|
|
pool.SetAccountToken("9", "ya29.new", "refresh", time.Now().Add(time.Hour))
|
|
|
|
require.False(t, inst.HasModelMappingReady())
|
|
require.False(t, inst.HasModelMappingUnavailable())
|
|
require.Empty(t, inst.ModelMappingUnavailableReason())
|
|
}
|
|
|
|
func TestShouldFallbackDirectForModelMappingUnavailable(t *testing.T) {
|
|
require.True(t, shouldFallbackDirect(fmt.Errorf("%w: oauth2 unauthorized_client", errLSModelMapDenied)))
|
|
require.False(t, shouldFallbackDirect(errLSModelMapPending))
|
|
}
|
|
|
|
func TestParseLSReplicaCountDefaultAndEnv(t *testing.T) {
|
|
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "")
|
|
require.Equal(t, 5, parseLSReplicaCount())
|
|
|
|
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "3")
|
|
require.Equal(t, 3, parseLSReplicaCount())
|
|
|
|
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "0")
|
|
require.Equal(t, 5, parseLSReplicaCount())
|
|
}
|
|
|
|
func TestPoolGetUsesStickyReplicaSlot(t *testing.T) {
|
|
pool := &Pool{
|
|
config: Config{ReplicasPerAccount: 5},
|
|
instances: map[string][]*Instance{
|
|
"acc-1": {
|
|
{AccountID: "acc-1", Replica: 0, healthy: true},
|
|
{AccountID: "acc-1", Replica: 1, healthy: true},
|
|
{AccountID: "acc-1", Replica: 2, healthy: true},
|
|
{AccountID: "acc-1", Replica: 3, healthy: true},
|
|
{AccountID: "acc-1", Replica: 4, healthy: true},
|
|
},
|
|
},
|
|
}
|
|
|
|
routingKey := "acc-1:user-a:session-1"
|
|
slot := replicaSlotIndex(routingKey, pool.replicaCount())
|
|
inst := pool.Get("acc-1", routingKey)
|
|
require.NotNil(t, inst)
|
|
require.Equal(t, slot, inst.Replica)
|
|
}
|
|
|
|
func TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica(t *testing.T) {
|
|
busy := &Instance{AccountID: "acc-1", Replica: 0, healthy: true}
|
|
atomic.StoreInt64(&busy.inflight, 4)
|
|
idle := &Instance{AccountID: "acc-1", Replica: 1, healthy: true}
|
|
atomic.StoreInt64(&idle.inflight, 1)
|
|
|
|
pool := &Pool{
|
|
config: Config{ReplicasPerAccount: 5},
|
|
instances: map[string][]*Instance{
|
|
"acc-1": {busy, idle},
|
|
},
|
|
}
|
|
|
|
inst := pool.Get("acc-1", "")
|
|
require.NotNil(t, inst)
|
|
require.Equal(t, 1, inst.Replica)
|
|
}
|
|
|
|
func TestWaitForInstanceReadyProbesImmediately(t *testing.T) {
|
|
startedAt := time.Now()
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
|
|
attempts, err := waitForInstanceReady(ctx, 200*time.Millisecond, func(context.Context) error {
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, attempts)
|
|
require.Less(t, time.Since(startedAt), 100*time.Millisecond)
|
|
}
|
|
|
|
func TestWaitForInstanceReadyRetriesUntilSuccess(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
|
|
calls := 0
|
|
attempts, err := waitForInstanceReady(ctx, 10*time.Millisecond, func(context.Context) error {
|
|
calls++
|
|
if calls < 3 {
|
|
return errors.New("not ready")
|
|
}
|
|
return nil
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, 3, attempts)
|
|
require.Equal(t, 3, calls)
|
|
}
|
|
|
|
func TestDecideJSParityRoute(t *testing.T) {
|
|
body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
|
parsed, err := parseGeminiRequest(body)
|
|
require.NoError(t, err)
|
|
decision := decideJSParityRoute(parsed, body)
|
|
require.True(t, decision.UseLS)
|
|
|
|
imageBody := []byte(`{"model":"gemini-2.5-flash-image","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"draw"}]}],"generationConfig":{"responseModalities":["TEXT","IMAGE"]}}}`)
|
|
parsedImage, err := parseGeminiRequest(imageBody)
|
|
require.NoError(t, err)
|
|
decisionImage := decideJSParityRoute(parsedImage, imageBody)
|
|
require.False(t, decisionImage.UseLS)
|
|
require.Contains(t, strings.ToLower(decisionImage.Reason), "image")
|
|
|
|
noSessionBody := []byte(`{"model":"gemini-2.5-flash","request":{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
|
parsedNoSession, err := parseGeminiRequest(noSessionBody)
|
|
require.NoError(t, err)
|
|
require.False(t, decideJSParityRoute(parsedNoSession, noSessionBody).UseLS)
|
|
}
|
|
|
|
func TestUserNamespacePrefersExplicitHeader(t *testing.T) {
|
|
req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
|
require.NoError(t, err)
|
|
req.Header.Set(userNamespaceHeader, "tenant-a")
|
|
req.Header.Set("Authorization", "Bearer oauth-token")
|
|
|
|
nsWithExplicit := userNamespace(req)
|
|
require.NotEqual(t, "anon", nsWithExplicit)
|
|
|
|
req.Header.Del(userNamespaceHeader)
|
|
nsWithAuth := userNamespace(req)
|
|
require.NotEqual(t, "anon", nsWithAuth)
|
|
require.NotEqual(t, nsWithExplicit, nsWithAuth)
|
|
}
|
|
|
|
func TestConversationPrefixEqual(t *testing.T) {
|
|
prefix := []geminiConversationTurn{
|
|
{Role: "user", Parts: []geminiConversationPart{{Kind: "text", Text: "hello"}}},
|
|
{Role: "model", Parts: []geminiConversationPart{{Kind: "text", Text: "world"}}},
|
|
}
|
|
full := append(cloneConversationTurns(prefix), geminiConversationTurn{
|
|
Role: "user",
|
|
Parts: []geminiConversationPart{{Kind: "text", Text: "follow up"}},
|
|
})
|
|
require.True(t, conversationPrefixEqual(full, prefix))
|
|
require.False(t, conversationPrefixEqual(prefix, full))
|
|
}
|
|
|
|
func TestSystemTextCompatible(t *testing.T) {
|
|
require.True(t, systemTextCompatible("You are helpful", ""))
|
|
require.True(t, systemTextCompatible("You are helpful", "You are helpful"))
|
|
require.False(t, systemTextCompatible("", "You are helpful"))
|
|
require.False(t, systemTextCompatible("You are helpful", "You are different"))
|
|
}
|
|
|
|
type recordingUpstream struct {
|
|
doCalls int
|
|
}
|
|
|
|
func (r *recordingUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
|
r.doCalls++
|
|
return &http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString("ok")), Header: make(http.Header), Request: req}, nil
|
|
}
|
|
|
|
func (r *recordingUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, c int, _ *tlsfingerprint.Profile) (*http.Response, error) {
|
|
return r.Do(req, proxyURL, accountID, c)
|
|
}
|