sub2api/backend/internal/pkg/lspool/upstream_adapter.go
win 0e9f780815
Some checks failed
CI / test (push) Failing after 3s
CI / golangci-lint (push) Failing after 3s
Security Scan / backend-security (push) Failing after 4s
Security Scan / frontend-security (push) Failing after 3s
fix: surface ls quota exhaustion in antigravity streams
2026-03-31 01:26:48 +08:00

1646 lines
48 KiB
Go

// Package lspool provides an HTTPUpstream adapter that routes
// streamGenerateContent requests through real Language Server instances.
//
// Flow:
//
// sub2api → LSPoolUpstream.Do() → StartCascade → SendUserCascadeMessage
// → LS internally calls cloudcode-pa (with authentic TLS fingerprint)
// → Poll GetCascadeTrajectory for incremental text
// → Format as SSE and stream back to sub2api service layer
//
// The model is extracted from the original request body, not hardcoded.
package lspool
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
)
// Upstream is the interface matching service.HTTPUpstream.
//
// LSPoolUpstream holds one as its fallback and also provides Do and DoWithTLS
// methods with these same signatures.
type Upstream interface {
// Do sends req on behalf of accountID, passing proxyURL and the account's
// concurrency setting through to the implementation.
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
// DoWithTLS is Do with an additional TLS fingerprint profile.
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error)
}
// LSPoolUpstream wraps an existing HTTPUpstream and intercepts
// streamGenerateContent requests to route them through the LS pool.
type LSPoolUpstream struct {
// pool hands out Language Server instances and stores per-account tokens
// and credit settings.
pool Backend
// fallback receives every request that is not (or cannot be) served via LS.
fallback Upstream
// logger is tagged with component=lspool-upstream by NewLSPoolUpstream.
logger *slog.Logger
// sessionMu presumably guards sessions; the accessors (getSessionState /
// putSessionState) are outside this view — TODO confirm.
sessionMu sync.Mutex
// sessions caches cascade state per session cache key.
sessions map[string]*cascadeSessionState
}
// NewLSPoolUpstream builds an LSPoolUpstream backed by pool, delegating to
// fallback for requests that bypass the Language Server. The returned value
// starts with an empty session cache and a logger derived from slog.Default.
func NewLSPoolUpstream(pool Backend, fallback Upstream) *LSPoolUpstream {
	up := new(LSPoolUpstream)
	up.pool = pool
	up.fallback = fallback
	up.logger = slog.Default().With("component", "lspool-upstream")
	up.sessions = make(map[string]*cascadeSessionState)
	return up
}
const (
// userNamespaceHeader is an internal header; it is removed before a request
// is forwarded directly to the fallback upstream.
userNamespaceHeader = "X-Sub2API-User-Key"
// The three headers below carry per-account AI-credit settings. They are
// parsed and then stripped in extractAndStripInternalHeaders.
useAICreditsHeader = "X-Antigravity-Use-AI-Credits"
availableCreditsHeader = "X-Antigravity-Available-Credits"
minimumCreditAmountHeader = "X-Antigravity-Minimum-Credit-Amount"
// sessionStateTTL is presumably the lifetime of a cached cascade session;
// its use is outside this view — TODO confirm.
sessionStateTTL = 30 * time.Minute
// lsSendMessageTimeout bounds the SendUserCascadeMessage call.
lsSendMessageTimeout = 20 * time.Second
// lsModelConfigTimeout is not referenced in the visible part of the file;
// presumably it bounds model-config calls — TODO confirm.
lsModelConfigTimeout = 20 * time.Second
)
// Sentinel errors used to steer routing between the LS path and the direct
// upstream. Compare with errors.Is.
var (
// errLSRouteDirect means the request cannot be served via LS and should be
// sent to the fallback upstream instead.
errLSRouteDirect = errors.New("request should use direct upstream")
// errLSTranscriptDrift means the incoming conversation no longer matches
// the cached cascade session.
errLSTranscriptDrift = errors.New("request transcript diverged from cached cascade session")
// errLSQuotaExhausted wraps an LS error message that was classified as a
// quota exhaustion while streaming.
errLSQuotaExhausted = errors.New("ls cascade returned quota exhausted")
// errLSModelMapPending is returned when the LS instance reports that its
// model mapping is not ready yet.
errLSModelMapPending = errors.New("model mapping not ready")
)
// IsLSQuotaExhaustedError reports whether err is, or wraps, the sentinel that
// marks an LS cascade quota/capacity exhaustion. A nil err yields false.
func IsLSQuotaExhaustedError(err error) bool {
	matched := errors.Is(err, errLSQuotaExhausted)
	return matched
}
// LSQuotaExhaustedMessage extracts the original LS error message, if present.
//
// Quota errors are produced as "<sentinel>: <LS message>". Callers may wrap
// that error again (IsLSQuotaExhaustedError still matches through wrapping),
// so the sentinel text is located anywhere in the message rather than only at
// the start.
//
// Returns:
//   - "" for a nil error, an empty message, or a sentinel with no detail;
//   - the text after "<sentinel>:" when the sentinel is found;
//   - the trimmed error message unchanged when the sentinel text is absent.
func LSQuotaExhaustedMessage(err error) string {
	if err == nil {
		return ""
	}
	msg := strings.TrimSpace(err.Error())
	if msg == "" {
		return ""
	}
	sentinel := errLSQuotaExhausted.Error()
	idx := strings.Index(msg, sentinel)
	if idx < 0 {
		return msg
	}
	rest := strings.TrimSpace(msg[idx+len(sentinel):])
	switch {
	case rest == "":
		// Sentinel with nothing after it: there is no LS message to report.
		return ""
	case strings.HasPrefix(rest, ":"):
		return strings.TrimSpace(rest[1:])
	default:
		// The sentinel text is followed by something other than ":", so this
		// is not the "<sentinel>: <detail>" shape; leave the message intact.
		return msg
	}
}
// cascadeSessionState is the cached record of one LS cascade conversation.
// It is stored after a streamed response completes and consulted on the next
// request with the same session key.
type cascadeSessionState struct {
// CascadeID is the LS cascade reused for follow-up turns.
CascadeID string
// SystemText is the system instruction the cascade was started with.
SystemText string
// History is the request's turns plus the model reply that was streamed.
History []geminiConversationTurn
// UpdatedAt is set to the time the state was stored.
UpdatedAt time.Time
}
// geminiEnvelope is the outer request wrapper: a model name plus the raw
// Gemini request. When Request is empty the whole body is parsed as the
// request instead (see parseGeminiRequest).
type geminiEnvelope struct {
Model string `json:"model"`
Request json.RawMessage `json:"request"`
}
// geminiRequestPayload is the subset of the Gemini request body that this
// package decodes.
type geminiRequestPayload struct {
// Contents are the conversation turns in wire form.
Contents []geminiWireContent `json:"contents"`
// SystemInstruction is optional; only its text parts are used.
SystemInstruction *geminiWireContent `json:"systemInstruction,omitempty"`
// GenerationConfig is optional and may be nil.
GenerationConfig *geminiWireGenerationConfig `json:"generationConfig,omitempty"`
// SessionID identifies the conversation; requests without it are not
// routed through LS.
SessionID string `json:"sessionId,omitempty"`
}
// geminiWireGenerationConfig holds the generationConfig fields that influence
// routing. ImageConfig is kept raw because only its presence is checked.
type geminiWireGenerationConfig struct {
ResponseModalities []string `json:"responseModalities,omitempty"`
ImageConfig json.RawMessage `json:"imageConfig,omitempty"`
}
// geminiWireContent is one content entry (a turn or the system instruction)
// as it appears in the request JSON.
type geminiWireContent struct {
Role string `json:"role"`
Parts []geminiWirePart `json:"parts"`
}
// geminiWirePart is one part of a content entry in wire form.
// parseGeminiRequest keeps text and inline-data parts and flags thought and
// function-call/response parts as unsupported.
type geminiWirePart struct {
Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
ThoughtSignature string `json:"thoughtSignature,omitempty"`
InlineData *geminiWireInlineData `json:"inlineData,omitempty"`
FunctionCall map[string]any `json:"functionCall,omitempty"`
FunctionResponse map[string]any `json:"functionResponse,omitempty"`
}
// geminiWireInlineData is an inline media attachment. Data is later decoded
// with standard base64 when building LS input.
type geminiWireInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
// geminiParsedRequest is the normalized view of a request produced by
// parseGeminiRequest and consumed by the routing decision and the LS path.
type geminiParsedRequest struct {
// Model is the requested model name from the envelope.
Model string
// SessionID is copied from the request payload.
SessionID string
// SystemText is the system instruction's text parts joined with newlines.
SystemText string
// Turns holds only turns that contain at least one text or media part.
Turns []geminiConversationTurn
// ResponseModalities is a copy of generationConfig.responseModalities.
ResponseModalities []string
// HasImageConfig is true when generationConfig.imageConfig is present and
// not JSON null.
HasImageConfig bool
// HasUnsupported is true when any part is a thought or a function
// call/response.
HasUnsupported bool
}
// geminiConversationTurn is one normalized conversation turn. Role is either
// "model" or "user" (see normalizeTurnRole).
type geminiConversationTurn struct {
Role string
Parts []geminiConversationPart
}
// geminiConversationPart is one normalized part of a turn.
type geminiConversationPart struct {
// Kind is "text" or "media".
Kind string
// Text is set for Kind "text".
Text string
// MimeType and Data are set for Kind "media"; Data is the base64 payload
// copied from inlineData.
MimeType string
Data string
}
// lsRouteDecision is the outcome of decideJSParityRoute: whether to use the
// LS cascade path, and a human-readable reason used for logging.
type lsRouteDecision struct {
UseLS bool
Reason string
}
// lsRequestTrace accumulates timing and identity data for one LS-routed
// request. logTraceSummary turns it into structured log attributes.
type lsRequestTrace struct {
// StartedAt is the reference point for total and first-poll/first-text
// latencies.
StartedAt time.Time
AccountID int64
Model string
// SessionIDHash is a short hash of the session ID (see shortTraceID), so
// the raw ID is not logged.
SessionIDHash string
// Replica is copied from the LS instance serving the request.
Replica int
CascadeID string
// NewSession is true when a cascade was started for this request.
NewSession bool
// InflightAtAcquire is the instance's concurrent count right after this
// request acquired its slot.
InflightAtAcquire int64
// TurnCount is the number of parsed turns in the request.
TurnCount int
GetOrCreateDuration time.Duration
StartCascadeDuration time.Duration
BuildInputDuration time.Duration
SendMessageDuration time.Duration
// FirstPollLatency and FirstTextLatency are measured from StartedAt.
FirstPollLatency time.Duration
FirstTextLatency time.Duration
// PollCount counts GetCascadeTrajectory attempts, including failed ones.
PollCount int
}
// Do routes streamGenerateContent through LS, everything else through fallback.
//
// Internal headers are consumed and stripped first. Requests that are not
// streamGenerateContent, or that have an empty body, go straight to the
// fallback. If the LS path reports a fallback-eligible error, the body is
// restored from the snapshot and the request is retried on the fallback; any
// other LS error is returned as-is.
func (u *LSPoolUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	u.extractAndStripInternalHeaders(req, strconv.FormatInt(accountID, 10))

	direct := func() (*http.Response, error) {
		return u.fallback.Do(req, proxyURL, accountID, accountConcurrency)
	}

	if !isStreamGenerate(req.URL.Path) {
		return direct()
	}

	payload, snapErr := snapshotRequestBody(req)
	if snapErr != nil {
		return nil, fmt.Errorf("snapshot request body: %w", snapErr)
	}
	if len(bytes.TrimSpace(payload)) == 0 {
		return direct()
	}

	resp, lsErr := u.doViaLS(req, payload, accountID, proxyURL)
	switch {
	case lsErr == nil:
		return resp, nil
	case shouldFallbackDirect(lsErr):
		u.logger.Warn("[LS-POOL] LS fell back to direct", "account", accountID, "err", lsErr)
		// Rewind the body so the fallback sees the full original payload.
		req.Body = io.NopCloser(bytes.NewReader(payload))
		return direct()
	default:
		return nil, lsErr
	}
}
// DoWithTLS mirrors Do for callers that supply a TLS fingerprint profile.
// LS handles its own TLS, so profile is ignored for LS requests; it is only
// passed along when this method itself delegates to the fallback.
func (u *LSPoolUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
	u.extractAndStripInternalHeaders(req, strconv.FormatInt(accountID, 10))

	direct := func() (*http.Response, error) {
		return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile)
	}

	if !isStreamGenerate(req.URL.Path) {
		return direct()
	}

	payload, snapErr := snapshotRequestBody(req)
	if snapErr != nil {
		return nil, fmt.Errorf("snapshot request body: %w", snapErr)
	}
	if len(bytes.TrimSpace(payload)) == 0 {
		return direct()
	}

	resp, lsErr := u.doViaLS(req, payload, accountID, proxyURL)
	switch {
	case lsErr == nil:
		return resp, nil
	case shouldFallbackDirect(lsErr):
		u.logger.Warn("[LS-POOL] LS fell back to direct+TLS", "account", accountID, "err", lsErr)
		// Rewind the body so the fallback sees the full original payload.
		req.Body = io.NopCloser(bytes.NewReader(payload))
		return direct()
	default:
		return nil, lsErr
	}
}
// doViaLS picks the concrete path for a streamGenerateContent request.
//
// Unless the active strategy is LSStrategyJSParity, the request is forwarded
// directly while an LS instance is kept alive. Under JS parity the body is
// parsed and checked for cascade eligibility; unparseable or ineligible
// requests, and cascade attempts that end in a fallback-eligible error, are
// forwarded directly. Other cascade errors are returned to the caller.
func (u *LSPoolUpstream) doViaLS(req *http.Request, body []byte, accountID int64, proxyURL string) (*http.Response, error) {
	accountKey := strconv.FormatInt(accountID, 10)
	if CurrentLSStrategy() != LSStrategyJSParity {
		return u.forwardDirectWithKeepalive(req, body, accountKey, accountID, proxyURL)
	}

	parsed, parseErr := parseGeminiRequest(body)
	if parseErr != nil {
		return u.forwardDirect(req, body, proxyURL, accountID, "parse request failed")
	}

	if route := decideJSParityRoute(parsed, body); !route.UseLS {
		return u.forwardDirect(req, body, proxyURL, accountID, route.Reason)
	}

	resp, chatErr := u.forwardChatViaLS(req, body, parsed, accountKey, accountID, proxyURL)
	if chatErr == nil {
		return resp, nil
	}
	if !shouldFallbackDirect(chatErr) {
		return nil, chatErr
	}
	return u.forwardDirect(req, body, proxyURL, accountID, chatErr.Error())
}
// shouldFallbackDirect reports whether err means "serve this request via the
// direct upstream instead of LS" rather than a genuine failure.
//
// errLSModelMapPending is included: forwardChatViaLS returns it right after
// logging that the request is being routed direct, so treating it as a hard
// error would fail requests that are merely waiting for the model mapping.
func shouldFallbackDirect(err error) bool {
	return errors.Is(err, errLSRouteDirect) ||
		errors.Is(err, errLSTranscriptDrift) ||
		errors.Is(err, errLSModelMapPending)
}
// forwardDirectWithKeepalive makes sure an LS instance exists for the account
// and then sends the original request directly through the fallback.
//
// The LS process is not used as a proxy here: starting or reusing it keeps
// its heartbeat alive, authenticates with cloudcode-pa and refreshes the
// model mapping, while the actual HTTP request bypasses Cascade entirely so
// the IDE agent system prompt that Cascade injects is avoided.
func (u *LSPoolUpstream) forwardDirectWithKeepalive(req *http.Request, body []byte, accountKey string, accountID int64, proxyURL string) (*http.Response, error) {
	if _, poolErr := u.pool.GetOrCreate(accountKey, "", proxyURL); poolErr != nil {
		return nil, fmt.Errorf("get LS instance: %w", poolErr)
	}

	u.logger.Info("[LS-POOL] Forwarding via direct HTTP (LS keepalive active)",
		"account", accountID, "path", req.URL.Path)

	return u.forwardDirect(req, body, proxyURL, accountID, "strategy=direct")
}
// forwardDirect sends the request to the fallback upstream, logging reason.
// The internal user-namespace header is removed and the body is rebuilt from
// the snapshot before the call.
//
// NOTE(review): this always calls fallback.Do with a concurrency of 1, even
// when the request entered through DoWithTLS, so the caller's TLS profile and
// accountConcurrency are not used on this path — confirm this is intended.
func (u *LSPoolUpstream) forwardDirect(req *http.Request, body []byte, proxyURL string, accountID int64, reason string) (*http.Response, error) {
	u.logger.Info("[LS-POOL] Forwarding via direct HTTP",
		"account", accountID,
		"path", req.URL.Path,
		"reason", reason)

	req.Body = io.NopCloser(bytes.NewReader(body))
	req.Header.Del(userNamespaceHeader)

	const directConcurrency = 1
	return u.fallback.Do(req, proxyURL, accountID, directConcurrency)
}
// extractAndStripInternalHeaders copies account credentials and AI-credit
// settings from request headers into the pool, then removes the internal
// headers so they are not forwarded.
//
// A "Bearer " Authorization header supplies the access token; the refresh
// token and RFC 3339 expiry come from X-Antigravity-* headers (an unparseable
// expiry leaves the zero time). Credit settings are pushed to the pool only
// when at least one credit header parsed successfully. Authorization itself
// is left on the request.
func (u *LSPoolUpstream) extractAndStripInternalHeaders(req *http.Request, accountKey string) {
	const (
		refreshTokenHeader = "X-Antigravity-Refresh-Token"
		tokenExpiryHeader  = "X-Antigravity-Token-Expiry"
	)
	hdr := req.Header

	if accessToken, isBearer := strings.CutPrefix(hdr.Get("Authorization"), "Bearer "); isBearer {
		refreshToken := hdr.Get(refreshTokenHeader)
		var expiresAt time.Time
		if rawExpiry := hdr.Get(tokenExpiryHeader); rawExpiry != "" {
			if ts, parseErr := time.Parse(time.RFC3339, rawExpiry); parseErr == nil {
				expiresAt = ts
			}
		}
		u.pool.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt)
	}

	useCredits, gotUse := parseBoolHeader(hdr.Get(useAICreditsHeader))
	available, gotAvailable := parseOptionalInt32Header(hdr.Get(availableCreditsHeader))
	minimum, gotMinimum := parseOptionalInt32Header(hdr.Get(minimumCreditAmountHeader))
	if gotUse || gotAvailable || gotMinimum {
		u.pool.SetAccountModelCredits(accountKey, useCredits, available, minimum)
	}

	for _, name := range []string{
		refreshTokenHeader,
		tokenExpiryHeader,
		useAICreditsHeader,
		availableCreditsHeader,
		minimumCreditAmountHeader,
	} {
		hdr.Del(name)
	}
}
// parseBoolHeader interprets a header value as a boolean.
// The first result is the parsed value; the second reports whether the value
// was present and valid. Blank or unparseable input yields (false, false).
func parseBoolHeader(raw string) (bool, bool) {
	// strconv.ParseBool rejects the empty string, so blank input needs no
	// separate check.
	if val, err := strconv.ParseBool(strings.TrimSpace(raw)); err == nil {
		return val, true
	}
	return false, false
}
// parseOptionalInt32Header interprets a header value as a base-10 int32.
// It returns a pointer to the parsed value and true on success, or
// (nil, false) when the value is blank, malformed, or out of int32 range.
func parseOptionalInt32Header(raw string) (*int32, bool) {
	// strconv.ParseInt rejects the empty string, so blank input needs no
	// separate check.
	wide, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 32)
	if err != nil {
		return nil, false
	}
	out := new(int32)
	*out = int32(wide)
	return out, true
}
// shortTraceID returns a short, log-safe identifier for raw: the first four
// bytes of its SHA-256 digest as lowercase hex. Surrounding whitespace is
// ignored, and blank input maps to the literal "none".
func shortTraceID(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if len(trimmed) == 0 {
		return "none"
	}
	digest := sha256.Sum256([]byte(trimmed))
	prefix := digest[:4]
	return fmt.Sprintf("%x", prefix)
}
// durationMS converts d to whole milliseconds for logging, clamping zero and
// negative durations to 0.
func durationMS(d time.Duration) int64 {
	if d > 0 {
		return d.Milliseconds()
	}
	return 0
}
// logTraceSummary logs msg at level with the standard set of trace
// attributes followed by any extra key/value pairs. With a nil trace only
// the extra pairs are logged.
func (u *LSPoolUpstream) logTraceSummary(level slog.Level, msg string, trace *lsRequestTrace, extra ...any) {
	ctx := context.Background()
	if trace == nil {
		u.logger.Log(ctx, level, msg, extra...)
		return
	}

	// 16 fixed key/value pairs, then whatever the caller appended.
	fields := make([]any, 0, 32+len(extra))
	fields = append(fields,
		"account", trace.AccountID,
		"model", trace.Model,
		"session", trace.SessionIDHash,
		"replica", trace.Replica,
		"cascade", shortTraceID(trace.CascadeID),
		"new_session", trace.NewSession,
		"turns", trace.TurnCount,
		"inflight", trace.InflightAtAcquire,
		"get_or_create_ms", durationMS(trace.GetOrCreateDuration),
		"start_cascade_ms", durationMS(trace.StartCascadeDuration),
		"build_input_ms", durationMS(trace.BuildInputDuration),
		"send_message_ms", durationMS(trace.SendMessageDuration),
		"first_poll_ms", durationMS(trace.FirstPollLatency),
		"first_token_ms", durationMS(trace.FirstTextLatency),
		"polls", trace.PollCount,
		"total_ms", durationMS(time.Since(trace.StartedAt)),
	)
	fields = append(fields, extra...)
	u.logger.Log(ctx, level, msg, fields...)
}
// forwardChatViaLS serves one chat request through an LS cascade and returns
// a synthetic streaming HTTP response.
//
// Steps:
//  1. Obtain an LS instance for the account/session and take a concurrency
//     slot on it.
//  2. Look up cached session state. With no state, a new cascade is started
//     and all turns before the last one (plus the system text) are rendered
//     into a context prefix. With state, the request must extend the cached
//     history by exactly one user turn, which is sent to the existing cascade.
//  3. Send the turn with SendUserCascadeMessage (bounded by
//     lsSendMessageTimeout).
//  4. Return a 200 text/event-stream response backed by an io.Pipe; a
//     goroutine polls the cascade, writes SSE chunks, stores the updated
//     session state on completion, and releases the concurrency slot.
//
// Returns errLSModelMapPending, errLSRouteDirect or errLSTranscriptDrift when
// the request should not be served by the cascade; other errors are wrapped
// failures from the pool or the LS calls. Every error return after the slot
// is acquired releases it first.
func (u *LSPoolUpstream) forwardChatViaLS(req *http.Request, body []byte, parsed *geminiParsedRequest, accountKey string, accountID int64, proxyURL string) (*http.Response, error) {
trace := &lsRequestTrace{
StartedAt: time.Now(),
AccountID: accountID,
Model: parsed.Model,
SessionIDHash: shortTraceID(parsed.SessionID),
TurnCount: len(parsed.Turns),
}
getOrCreateStartedAt := time.Now()
sessionKey := buildSessionCacheKey(accountID, userNamespace(req), parsed.SessionID)
inst, err := u.pool.GetOrCreate(accountKey, sessionKey, proxyURL)
if err != nil {
trace.GetOrCreateDuration = time.Since(getOrCreateStartedAt)
u.logTraceSummary(slog.LevelWarn, "[LS-POOL] get instance failed", trace, "err", err)
return nil, fmt.Errorf("get LS instance: %w", err)
}
trace.GetOrCreateDuration = time.Since(getOrCreateStartedAt)
trace.Replica = inst.Replica
// Checked before acquiring a slot, so nothing needs releasing here.
if !inst.HasModelMappingReady() {
u.logTraceSummary(slog.LevelInfo, "[LS-POOL] model mapping pending, routing direct", trace)
return nil, errLSModelMapPending
}
if !inst.AcquireConcurrency() {
u.logTraceSummary(slog.LevelWarn, "[LS-POOL] instance busy", trace,
"err", fmt.Sprintf("ls instance busy for account %d", accountID),
"current_inflight", inst.ConcurrentCount(),
"max_inflight", maxConcurrencyPerInstance)
return nil, fmt.Errorf("ls instance busy for account %d", accountID)
}
// From here on the slot is held: every early return must release it.
trace.InflightAtAcquire = inst.ConcurrentCount()
state := u.getSessionState(sessionKey)
// A changed system instruction means the cached cascade no longer matches.
if state != nil && !systemTextCompatible(state.SystemText, parsed.SystemText) {
inst.ReleaseConcurrency()
u.logTraceSummary(slog.LevelInfo, "[LS-POOL] transcript drift, routing direct", trace)
return nil, errLSTranscriptDrift
}
cascadeID := ""
newSession := false
sendTurn := geminiConversationTurn{}
contextPrefix := ""
switch {
case state == nil:
// No cached session: start a fresh cascade. Only the last turn is sent
// as the message; earlier turns are flattened into a text prefix.
if len(parsed.Turns) == 0 {
inst.ReleaseConcurrency()
return nil, errLSRouteDirect
}
lastTurn := parsed.Turns[len(parsed.Turns)-1]
if lastTurn.Role != "user" {
inst.ReleaseConcurrency()
u.logTraceSummary(slog.LevelInfo, "[LS-POOL] invalid first turn for LS, routing direct", trace)
return nil, errLSRouteDirect
}
sendTurn = lastTurn
contextPrefix = renderConversationContext(parsed.SystemText, parsed.Turns[:len(parsed.Turns)-1])
startCascadeStartedAt := time.Now()
cascadeID, err = u.startCascade(inst)
trace.StartCascadeDuration = time.Since(startCascadeStartedAt)
if err != nil {
inst.ReleaseConcurrency()
u.logTraceSummary(slog.LevelWarn, "[LS-POOL] start cascade failed", trace, "err", err)
return nil, err
}
newSession = true
case !conversationPrefixEqual(parsed.Turns, state.History):
inst.ReleaseConcurrency()
u.logTraceSummary(slog.LevelInfo, "[LS-POOL] transcript prefix mismatch, routing direct", trace)
return nil, errLSTranscriptDrift
default:
// Cached session matches: only a single new user turn may be appended.
// NOTE(review): this slice relies on conversationPrefixEqual guaranteeing
// len(parsed.Turns) >= len(state.History) — confirm in that helper.
delta := parsed.Turns[len(state.History):]
if len(delta) != 1 || delta[0].Role != "user" {
inst.ReleaseConcurrency()
u.logTraceSummary(slog.LevelInfo, "[LS-POOL] unsupported transcript delta, routing direct", trace)
return nil, errLSRouteDirect
}
sendTurn = delta[0]
cascadeID = state.CascadeID
}
trace.NewSession = newSession
trace.CascadeID = cascadeID
buildInputStartedAt := time.Now()
items, media, err := buildLSInputFromTurn(sendTurn, contextPrefix)
trace.BuildInputDuration = time.Since(buildInputStartedAt)
if err != nil {
inst.ReleaseConcurrency()
u.logTraceSummary(slog.LevelWarn, "[LS-POOL] build input failed", trace, "err", err)
return nil, fmt.Errorf("build ls input: %w", err)
}
if len(items) == 0 && len(media) == 0 {
inst.ReleaseConcurrency()
u.logTraceSummary(slog.LevelInfo, "[LS-POOL] empty LS input, routing direct", trace)
return nil, errLSRouteDirect
}
// "blocking": false — the call returns without waiting for the reply; the
// response text is collected by polling in the goroutine below.
sendReq := map[string]any{
"metadata": buildLSRequestMetadata(),
"cascadeId": cascadeID,
"items": items,
"blocking": false,
}
if len(media) > 0 {
sendReq["media"] = media
}
if cfg := buildCascadeConfig(parsed.Model); cfg != nil {
sendReq["cascadeConfig"] = cfg
}
sendStartedAt := time.Now()
sendCtx, sendCancel := context.WithTimeout(req.Context(), lsSendMessageTimeout)
defer sendCancel()
if _, err := inst.CallUnaryJSON(sendCtx, LSService, "SendUserCascadeMessage", sendReq); err != nil {
trace.SendMessageDuration = time.Since(sendStartedAt)
// Only a cascade created by this request is cancelled; a reused one is
// left alone.
if newSession {
u.cancelCascade(inst, cascadeID)
}
inst.ReleaseConcurrency()
u.logTraceSummary(slog.LevelWarn, "[LS-POOL] send user message failed", trace, "err", err)
return nil, fmt.Errorf("send user cascade message: %w", err)
}
trace.SendMessageDuration = time.Since(sendStartedAt)
// The caller reads SSE from pr; the goroutine below writes to pw and closes
// it when the cascade finishes, fails, or the request context ends.
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"Cache-Control": []string{"no-cache"},
"X-Accel-Buffering": []string{"no"},
},
Body: pr,
Request: req,
}
go func() {
// Ownership of the concurrency slot moves to this goroutine.
defer inst.ReleaseConcurrency()
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
u.streamCascadeResponse(ctx, inst, cascadeID, pw, trace, func(finalText string) {
// Invoked only on successful completion: remember the request's turns
// plus the model reply so the next request can be matched against it.
u.putSessionState(sessionKey, &cascadeSessionState{
CascadeID: cascadeID,
SystemText: parsed.SystemText,
History: appendModelTurn(cloneConversationTurns(parsed.Turns), finalText),
UpdatedAt: time.Now(),
})
})
}()
return resp, nil
}
// startCascade asks the LS instance to open a new cascade and returns its ID.
//
// The RPC is bounded by a timeout: the caller holds a concurrency slot on the
// instance while this runs, and a background context with no deadline would
// let an unresponsive LS block the request (and the slot) indefinitely.
//
// Errors: a wrapped RPC error, a wrapped JSON decode error, or an error when
// the response carries an empty cascadeId.
func (u *LSPoolUpstream) startCascade(inst *Instance) (string, error) {
	// Same budget as the other per-request LS calls in this file (20s).
	const startCascadeTimeout = 20 * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), startCascadeTimeout)
	defer cancel()

	resp, err := inst.CallUnaryJSON(ctx, LSService, "StartCascade", map[string]any{
		"metadata": buildLSRequestMetadata(),
	})
	if err != nil {
		return "", fmt.Errorf("start cascade: %w", err)
	}
	var decoded struct {
		CascadeID string `json:"cascadeId"`
	}
	if err := json.Unmarshal(resp, &decoded); err != nil {
		return "", fmt.Errorf("decode start cascade: %w", err)
	}
	if decoded.CascadeID == "" {
		return "", errors.New("start cascade returned empty cascadeId")
	}
	return decoded.CascadeID, nil
}
// cancelCascade tells the LS to stop processing a cascade invocation.
// Best-effort with short timeouts — it must not block the caller for long.
//
// CancelCascadeInvocation is tried first; if it fails, ForceStopCascadeTree
// is tried as a fallback. Each call gets its own 3-second deadline: with a
// shared context, a first call that failed by timing out would leave the
// fallback with an already-expired context and no chance to run. Worst case
// this blocks for about 6 seconds.
func (u *LSPoolUpstream) cancelCascade(inst *Instance, cascadeID string) {
	const perCallTimeout = 3 * time.Second

	call := func(method string) error {
		ctx, cancel := context.WithTimeout(context.Background(), perCallTimeout)
		defer cancel()
		_, err := inst.CallUnaryJSON(ctx, LSService, method, map[string]any{
			"cascadeId": cascadeID,
		})
		return err
	}

	if err := call("CancelCascadeInvocation"); err != nil {
		// Try force stop as fallback. Its error is deliberately dropped:
		// there is nothing further to do if both attempts fail.
		_ = call("ForceStopCascadeTree")
	}
}
// streamCascadeResponse polls GetCascadeTrajectory with adaptive interval.
// Fast (50ms) when model is generating, slow (150ms) when waiting.
// We also issue an immediate first poll so the first token is not delayed by
// the initial ticker interval.
//
// Each poll extracts the planner response text; any newly appended suffix is
// written to w as one SSE chunk. The stream ends when:
//   - the cascade reports an error (w is closed with that error, wrapped in
//     errLSQuotaExhausted when it is a quota error);
//   - the cascade is idle with non-empty text and not generating (a final
//     usage chunk is written when usage is available, onDone receives the
//     full text, and w is closed normally);
//   - a write to w fails (w is closed with that error);
//   - ctx is cancelled, maxDuration elapses, or no new text arrives for
//     maxIdleTimeout after the first text (w is closed normally).
//
// A poll that fails while ctx is still live is skipped and retried on the
// next tick. trace may be nil; onDone may be nil.
func (u *LSPoolUpstream) streamCascadeResponse(ctx context.Context, inst *Instance, cascadeID string, w *io.PipeWriter, trace *lsRequestTrace, onDone func(string)) {
	const (
		fastInterval   = 50 * time.Millisecond
		slowInterval   = 150 * time.Millisecond
		maxDuration    = 5 * time.Minute
		maxIdleTimeout = 30 * time.Second
	)
	ticker := time.NewTicker(slowInterval)
	defer ticker.Stop()
	// An explicit timer (instead of time.After) so it is released as soon as
	// the stream ends rather than lingering for the full maxDuration.
	deadline := time.NewTimer(maxDuration)
	defer deadline.Stop()

	lastText := ""
	generating := false
	lastProgressAt := time.Time{}

	// pollOnce performs one poll and reports whether streaming is finished
	// (in which case w has already been closed).
	pollOnce := func() bool {
		if trace != nil {
			trace.PollCount++
		}
		trajResp, err := inst.CallUnaryJSON(ctx, LSService, "GetCascadeTrajectory", map[string]any{
			"cascadeId": cascadeID,
		})
		if err != nil {
			if ctx.Err() != nil {
				u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request context canceled", trace)
				_ = w.Close()
				return true
			}
			// Transient poll failure: try again on the next tick.
			return false
		}
		if trace != nil && trace.FirstPollLatency == 0 {
			trace.FirstPollLatency = time.Since(trace.StartedAt)
		}
		state := extractPlannerResponseState(trajResp)
		text, isGenerating, status := state.Text, state.Generating, state.Status
		if state.ErrorMessage != "" {
			u.logTraceSummary(slog.LevelWarn, "[LS-POOL] Cascade terminated with error", trace, "error", state.ErrorMessage)
			if isQuotaExhaustedError(state.ErrorMessage) {
				_ = w.CloseWithError(fmt.Errorf("%w: %s", errLSQuotaExhausted, state.ErrorMessage))
			} else {
				_ = w.CloseWithError(errors.New(state.ErrorMessage))
			}
			return true
		}
		// Adaptive interval: fast when generating, slow when idle.
		if isGenerating && !generating {
			ticker.Reset(fastInterval)
			generating = true
		} else if !isGenerating && generating {
			ticker.Reset(slowInterval)
			generating = false
		}
		// Emit new text as SSE. Only growth is emitted; the new text is
		// assumed to extend the previously seen text.
		if text != lastText && len(text) > len(lastText) {
			newPart := text[len(lastText):]
			sseEvent := buildGeminiSSEChunk(newPart)
			if _, err := w.Write([]byte(sseEvent)); err != nil {
				u.logTraceSummary(slog.LevelWarn, "[LS-POOL] write SSE failed", trace, "err", err)
				_ = w.CloseWithError(err)
				return true
			}
			lastText = text
			lastProgressAt = time.Now()
			if trace != nil && trace.FirstTextLatency == 0 {
				trace.FirstTextLatency = time.Since(trace.StartedAt)
			}
		}
		// Check if done.
		if status == "CASCADE_RUN_STATUS_IDLE" && text != "" && !isGenerating {
			usage := extractUsageFromTrajectory(trajResp)
			if usage != nil {
				finalEvent := buildGeminiSSEFinalChunk(usage)
				if _, err := w.Write([]byte(finalEvent)); err != nil {
					u.logTraceSummary(slog.LevelWarn, "[LS-POOL] write final SSE failed", trace, "err", err)
					_ = w.CloseWithError(err)
					return true
				}
			}
			if onDone != nil {
				onDone(lastText)
			}
			u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request completed", trace)
			_ = w.Close()
			return true
		}
		// The idle timeout only applies once some text has been seen.
		if !lastProgressAt.IsZero() && time.Since(lastProgressAt) > maxIdleTimeout {
			u.logTraceSummary(slog.LevelWarn, "[LS-POOL] No progress, stopping", trace)
			_ = w.Close()
			return true
		}
		return false
	}

	if pollOnce() {
		return
	}
	for {
		select {
		case <-ctx.Done():
			u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request context canceled", trace)
			_ = w.Close()
			return
		case <-deadline.C:
			u.logTraceSummary(slog.LevelWarn, "[LS-POOL] Cascade timeout", trace)
			_ = w.Close()
			return
		case <-ticker.C:
			if pollOnce() {
				return
			}
		}
	}
}
// ============================================================
// SSE builders — match Gemini v1internal:streamGenerateContent?alt=sse format
// ============================================================
// buildGeminiSSEChunk wraps text in a single SSE "data:" event carrying one
// model candidate with one text part.
func buildGeminiSSEChunk(text string) string {
	// cloudcode-pa v1internal format: {"response": {"candidates": [...]}}
	content := map[string]any{
		"parts": []map[string]string{{"text": text}},
		"role":  "model",
	}
	candidate := map[string]any{"content": content}
	envelope := map[string]any{
		"response": map[string]any{
			"candidates": []map[string]any{candidate},
		},
	}
	// The payload contains only maps, slices and strings, so the marshal
	// error is ignored.
	payload, _ := json.Marshal(envelope)
	return "data: " + string(payload) + "\n\n"
}
// buildGeminiSSEFinalChunk builds the closing SSE event: an empty text part
// with finishReason "STOP" and the supplied usage as usageMetadata.
func buildGeminiSSEFinalChunk(usage map[string]any) string {
	candidate := map[string]any{
		"content": map[string]any{
			"parts": []map[string]string{{"text": ""}},
			"role":  "model",
		},
		"finishReason": "STOP",
	}
	response := map[string]any{
		"candidates":    []map[string]any{candidate},
		"usageMetadata": usage,
	}
	// Marshal error is ignored, as in the non-final chunk builder; on
	// failure the event carries an empty payload.
	payload, _ := json.Marshal(map[string]any{"response": response})
	return "data: " + string(payload) + "\n\n"
}
// ============================================================
// Trajectory parsing
// ============================================================
// cascadePlannerState is the digest of one GetCascadeTrajectory response.
type cascadePlannerState struct {
// Text is the response string of the last planner-response step that has one.
Text string
// Generating is true when any planner-response step is in the generating
// status.
Generating bool
// Status is the top-level "status" string of the trajectory response.
Status string
// ErrorMessage is the first error summary found anywhere in the response,
// or "" when none was found.
ErrorMessage string
}
// extractPlannerResponseState decodes a GetCascadeTrajectory response into a
// cascadePlannerState. Undecodable input yields the zero state. Status and
// ErrorMessage come from the whole document; Text and Generating come from
// trajectory.steps entries of type CORTEX_STEP_TYPE_PLANNER_RESPONSE, with
// later steps overriding Text from earlier ones.
func extractPlannerResponseState(trajResp []byte) cascadePlannerState {
	var state cascadePlannerState

	var doc map[string]any
	if json.Unmarshal(trajResp, &doc) != nil {
		return state
	}
	state.Status, _ = doc["status"].(string)
	state.ErrorMessage = findCascadeErrorMessage(doc)

	// Failed assertions leave nil values; indexing a nil map and ranging a
	// nil slice are both no-ops, so missing sections simply contribute nothing.
	trajectory, _ := doc["trajectory"].(map[string]any)
	steps, _ := trajectory["steps"].([]any)
	for _, rawStep := range steps {
		step, isObject := rawStep.(map[string]any)
		if !isObject || step["type"] != "CORTEX_STEP_TYPE_PLANNER_RESPONSE" {
			continue
		}
		if step["status"] == "CORTEX_STEP_STATUS_GENERATING" {
			state.Generating = true
		}
		planner, _ := step["plannerResponse"].(map[string]any)
		if reply, isString := planner["response"].(string); isString {
			state.Text = reply
		}
	}
	return state
}
// extractPlannerResponseText is a tuple-returning convenience wrapper around
// extractPlannerResponseState that drops the error message.
func extractPlannerResponseText(trajResp []byte) (text string, generating bool, status string) {
	parsed := extractPlannerResponseState(trajResp)
	text = parsed.Text
	generating = parsed.Generating
	status = parsed.Status
	return text, generating, status
}
// findCascadeErrorMessage searches a decoded JSON value depth-first for the
// first object that summarizeCascadeErrorMap recognises as an error and
// returns its summary, or "" when nothing matches.
//
// An object is checked itself before its children. Children of an object are
// visited in Go map iteration order, which is unspecified, so when several
// sibling branches contain errors the one returned may vary between calls.
func findCascadeErrorMessage(value any) string {
	if obj, isObject := value.(map[string]any); isObject {
		if summary := summarizeCascadeErrorMap(obj); summary != "" {
			return summary
		}
		for _, nested := range obj {
			if found := findCascadeErrorMessage(nested); found != "" {
				return found
			}
		}
		return ""
	}
	if list, isList := value.([]any); isList {
		for _, nested := range list {
			if found := findCascadeErrorMessage(nested); found != "" {
				return found
			}
		}
	}
	return ""
}
// summarizeCascadeErrorMap turns one JSON object into an error summary, or ""
// when the object does not look like an error.
//
// Precedence:
//  1. "fullError", then "userErrorMessage", when non-empty.
//  2. If "errorCode" is present: the distinct non-empty values of
//     "shortError", "message" and "details" joined with ": ", or
//     "cascade error: <code>" when all three are empty.
//  3. If "terminationReason" contains "ERROR" (case-insensitive): "message",
//     else "shortError".
func summarizeCascadeErrorMap(m map[string]any) string {
	for _, key := range []string{"fullError", "userErrorMessage"} {
		if val := cascadeStringField(m, key); val != "" {
			return val
		}
	}

	short := cascadeStringField(m, "shortError")
	details := cascadeStringField(m, "details")
	message := cascadeStringField(m, "message")

	if code, hasCode := m["errorCode"]; hasCode {
		pieces := make([]string, 0, 3)
		if short != "" {
			pieces = append(pieces, short)
		}
		if message != "" && message != short {
			pieces = append(pieces, message)
		}
		if details != "" && details != short && details != message {
			pieces = append(pieces, details)
		}
		if len(pieces) == 0 {
			return fmt.Sprintf("cascade error: %v", code)
		}
		return strings.Join(pieces, ": ")
	}

	reason := strings.ToUpper(cascadeStringField(m, "terminationReason"))
	if !strings.Contains(reason, "ERROR") {
		return ""
	}
	if message != "" {
		return message
	}
	// short may itself be empty, which correctly yields "no error".
	return short
}
// cascadeStringField returns m[key] with surrounding whitespace removed when
// it is a string, and "" when the key is absent or holds another type.
func cascadeStringField(m map[string]any, key string) string {
	// A missing key yields a nil interface, which fails the assertion just
	// like a non-string value does.
	if text, isString := m[key].(string); isString {
		return strings.TrimSpace(text)
	}
	return ""
}
// extractUsageFromTrajectory pulls token usage out of a GetCascadeTrajectory
// response and returns it in Gemini usageMetadata shape
// (promptTokenCount, candidatesTokenCount, totalTokenCount).
//
// It scans trajectory.steps for planner-response steps and returns the usage
// of the first one whose metadata.modelUsage has a positive input or output
// token count. Token counts are accepted both as decimal strings and as JSON
// numbers, so usage is not silently dropped if the LS emits plain numbers.
//
// Returns nil when the input is not valid JSON, the expected structure is
// missing, or no step reports a positive count.
func extractUsageFromTrajectory(trajResp []byte) map[string]any {
	var raw map[string]any
	if err := json.Unmarshal(trajResp, &raw); err != nil {
		return nil
	}
	traj, ok := raw["trajectory"].(map[string]any)
	if !ok {
		return nil
	}
	steps, ok := traj["steps"].([]any)
	if !ok {
		return nil
	}

	// tokenCount reads a count encoded as a string or a number; anything
	// else, or an unparseable/out-of-range value, counts as 0.
	tokenCount := func(v any) int {
		switch n := v.(type) {
		case string:
			parsed, err := strconv.Atoi(n)
			if err != nil {
				return 0
			}
			return parsed
		case float64:
			if n > 0 && n <= 1<<31-1 {
				return int(n)
			}
		}
		return 0
	}

	for _, s := range steps {
		sm, ok := s.(map[string]any)
		if !ok {
			continue
		}
		if sm["type"] != "CORTEX_STEP_TYPE_PLANNER_RESPONSE" {
			continue
		}
		meta, ok := sm["metadata"].(map[string]any)
		if !ok {
			continue
		}
		mu, ok := meta["modelUsage"].(map[string]any)
		if !ok {
			continue
		}
		inputN := tokenCount(mu["inputTokens"])
		outputN := tokenCount(mu["outputTokens"])
		if inputN > 0 || outputN > 0 {
			return map[string]any{
				"promptTokenCount":     inputN,
				"candidatesTokenCount": outputN,
				"totalTokenCount":      inputN + outputN,
			}
		}
	}
	return nil
}
// ============================================================
// Request parsing — dynamic model, no hardcoding
// ============================================================
// parseGeminiRequest decodes a request body into a geminiParsedRequest.
//
// The body may be an envelope ({"model": ..., "request": {...}}) or a bare
// Gemini request; when "request" is absent the body itself is decoded as the
// payload. System instruction text parts are joined with newlines. Each
// content entry becomes a turn holding its text and inline-data parts; thought
// and function call/response parts are not copied but set HasUnsupported.
// Entries that end up with no parts are omitted from Turns.
//
// Returns a JSON decoding error if either the envelope or the payload cannot
// be unmarshalled.
func parseGeminiRequest(body []byte) (*geminiParsedRequest, error) {
	var envelope geminiEnvelope
	if err := json.Unmarshal(body, &envelope); err != nil {
		return nil, err
	}

	payloadJSON := body
	if len(envelope.Request) > 0 {
		payloadJSON = envelope.Request
	}
	var payload geminiRequestPayload
	if err := json.Unmarshal(payloadJSON, &payload); err != nil {
		return nil, err
	}

	out := &geminiParsedRequest{
		Model:     envelope.Model,
		SessionID: payload.SessionID,
	}
	if cfg := payload.GenerationConfig; cfg != nil {
		out.ResponseModalities = append([]string(nil), cfg.ResponseModalities...)
		imageCfg := bytes.TrimSpace(cfg.ImageConfig)
		out.HasImageConfig = len(imageCfg) > 0 && string(imageCfg) != "null"
	}

	if out.Model == "" {
		// Second attempt at reading "model" from the top-level object; a
		// failure here just leaves Model empty.
		var top map[string]json.RawMessage
		if err := json.Unmarshal(body, &top); err == nil {
			_ = json.Unmarshal(top["model"], &out.Model)
		}
	}

	if sys := payload.SystemInstruction; sys != nil {
		out.SystemText = collectTextParts(sys.Parts)
	}

	for _, content := range payload.Contents {
		turn := geminiConversationTurn{Role: normalizeTurnRole(content.Role)}
		for _, part := range content.Parts {
			// Case order matters: an unsupported marker wins over any text or
			// media carried by the same part.
			switch {
			case part.Thought || part.ThoughtSignature != "",
				len(part.FunctionCall) > 0 || len(part.FunctionResponse) > 0:
				out.HasUnsupported = true
			case part.InlineData != nil:
				media := geminiConversationPart{
					Kind:     "media",
					MimeType: part.InlineData.MimeType,
					Data:     part.InlineData.Data,
				}
				turn.Parts = append(turn.Parts, media)
			case part.Text != "":
				text := geminiConversationPart{Kind: "text", Text: part.Text}
				turn.Parts = append(turn.Parts, text)
			}
		}
		if len(turn.Parts) == 0 {
			continue
		}
		out.Turns = append(out.Turns, turn)
	}
	return out, nil
}
// collectTextParts joins the non-empty Text fields of parts with newlines,
// in order. Parts without text are skipped; no text at all yields "".
func collectTextParts(parts []geminiWirePart) string {
	var joined strings.Builder
	wroteAny := false
	for i := range parts {
		text := parts[i].Text
		if text == "" {
			continue
		}
		if wroteAny {
			joined.WriteByte('\n')
		}
		joined.WriteString(text)
		wroteAny = true
	}
	return joined.String()
}
// GetResponseModalities returns the configured response modalities. It is
// safe to call on a nil receiver, in which case it returns nil.
func (g *geminiWireGenerationConfig) GetResponseModalities() []string {
	if g != nil {
		return g.ResponseModalities
	}
	return nil
}
// normalizeTurnRole maps a wire role to one of two values: "model" when the
// trimmed role equals "model" ignoring case, and "user" for everything else,
// including the empty string.
func normalizeTurnRole(role string) string {
	const modelRole = "model"
	if !strings.EqualFold(strings.TrimSpace(role), modelRole) {
		return "user"
	}
	return modelRole
}
// decideJSParityRoute decides whether a parsed request can be served by the
// LS cascade. The first failing check determines the reason; the checks, in
// order, reject: a nil request, tool use, a missing sessionId, unsupported
// parts, an image-generation model, an imageConfig, an IMAGE response
// modality, and an empty conversation. Anything else is accepted.
func decideJSParityRoute(parsed *geminiParsedRequest, body []byte) lsRouteDecision {
	reject := func(reason string) lsRouteDecision {
		return lsRouteDecision{Reason: reason}
	}

	if parsed == nil {
		return reject("nil parsed request")
	}

	switch {
	case requestHasTools(body):
		return reject("tools are not supported through cascade")
	case parsed.SessionID == "":
		return reject("missing sessionId")
	case parsed.HasUnsupported:
		return reject("request contains unsupported Gemini parts")
	case isImageGenerationModelName(parsed.Model):
		return reject("image generation model")
	case parsed.HasImageConfig:
		return reject("request has imageConfig")
	}

	for _, modality := range parsed.ResponseModalities {
		if strings.EqualFold(strings.TrimSpace(modality), "IMAGE") {
			return reject("responseModalities contains IMAGE")
		}
	}

	if len(parsed.Turns) == 0 {
		return reject("empty conversation")
	}
	return lsRouteDecision{UseLS: true, Reason: "js-parity cascade chat"}
}
// extractPromptAndModel returns the flattened prompt text and the model name
// from a request body. The prompt is taken from the "request" member when
// present, otherwise from the body itself. A body that is not a JSON object
// yields ("", "").
func extractPromptAndModel(body []byte) (string, string) {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		return "", ""
	}

	var model string
	if rawModel, present := fields["model"]; present {
		// A non-string model is tolerated and leaves model empty.
		_ = json.Unmarshal(rawModel, &model)
	}

	promptSource := body
	if inner, wrapped := fields["request"]; wrapped {
		promptSource = inner
	}
	return extractPromptFromGeminiRequest(promptSource), model
}
// extractPromptFromGeminiRequest flattens a Gemini request into one prompt
// string.
//
// Every non-empty text part becomes a labelled section: "[System]" for system
// instruction parts, "[Assistant]" for contents with role "model", "[User]"
// for any other role. Sections are joined with blank lines. As a special
// case, when there is exactly one section and no systemInstruction object,
// the section is returned with a leading "[User]\n" label removed.
// Invalid JSON or a request without text yields "".
func extractPromptFromGeminiRequest(data []byte) string {
	type textPart struct {
		Text string `json:"text"`
	}
	var req struct {
		Contents []struct {
			Parts []textPart `json:"parts"`
			Role  string     `json:"role"`
		} `json:"contents"`
		SystemInstruction *struct {
			Parts []textPart `json:"parts"`
		} `json:"systemInstruction"`
	}
	if json.Unmarshal(data, &req) != nil {
		return ""
	}

	const (
		systemLabel    = "[System]\n"
		userLabel      = "[User]\n"
		assistantLabel = "[Assistant]\n"
	)

	var sections []string
	// System instruction first, if present.
	if sys := req.SystemInstruction; sys != nil {
		for _, part := range sys.Parts {
			if part.Text != "" {
				sections = append(sections, systemLabel+part.Text)
			}
		}
	}
	// Then the full conversation history.
	for _, content := range req.Contents {
		label := userLabel
		if content.Role == "model" {
			label = assistantLabel
		}
		for _, part := range content.Parts {
			if part.Text != "" {
				sections = append(sections, label+part.Text)
			}
		}
	}

	switch {
	case len(sections) == 0:
		return ""
	case len(sections) == 1 && req.SystemInstruction == nil:
		// Simple single-message case: hand back the raw user text.
		return strings.TrimPrefix(sections[0], userLabel)
	}
	return strings.Join(sections, "\n\n")
}
// buildLSInputFromTurn converts one conversation turn into the "items" and
// "media" lists of a SendUserCascadeMessage request.
//
// A non-blank contextPrefix becomes the first text item, followed by the
// turn's non-empty text parts in order. Media parts are base64-decoded
// (standard encoding) into entries with "mimeType" and "inlineData". Parts of
// any other kind are ignored.
//
// Returns (nil, nil, err) if a media part's data is not valid base64.
func buildLSInputFromTurn(turn geminiConversationTurn, contextPrefix string) ([]map[string]any, []map[string]any, error) {
	textItems := make([]map[string]any, 0, len(turn.Parts)+1)
	attachments := make([]map[string]any, 0)

	if strings.TrimSpace(contextPrefix) != "" {
		textItems = append(textItems, map[string]any{"text": contextPrefix})
	}

	for _, part := range turn.Parts {
		if part.Kind == "text" {
			if part.Text != "" {
				textItems = append(textItems, map[string]any{"text": part.Text})
			}
			continue
		}
		if part.Kind != "media" {
			continue
		}
		rawBytes, decodeErr := base64.StdEncoding.DecodeString(part.Data)
		if decodeErr != nil {
			return nil, nil, fmt.Errorf("decode inlineData: %w", decodeErr)
		}
		attachments = append(attachments, map[string]any{
			"mimeType":   part.MimeType,
			"inlineData": rawBytes,
		})
	}
	return textItems, attachments, nil
}
// renderConversationContext renders a system text and a list of turns as a
// plain-text transcript.
//
// A non-blank system text is emitted (trimmed) as a "[System]" section. Each
// turn becomes a "[User]" or "[Assistant]" section (role "model" maps to
// Assistant) whose lines are its non-blank text parts plus one placeholder
// per media part of the form "[<kind> attachment: <mime type>]". Turns that
// render to nothing are skipped. Sections are separated by a blank line.
func renderConversationContext(systemText string, turns []geminiConversationTurn) string {
	var sections []string
	if sys := strings.TrimSpace(systemText); sys != "" {
		sections = append(sections, "[System]\n"+sys)
	}

	for _, turn := range turns {
		var lines []string
		for _, part := range turn.Parts {
			if part.Kind == "text" {
				if strings.TrimSpace(part.Text) != "" {
					lines = append(lines, part.Text)
				}
				continue
			}
			if part.Kind != "media" {
				continue
			}
			// Media payloads are not inlined; only a short marker naming the
			// attachment family and MIME type is rendered.
			label := "attachment"
			for _, family := range []string{"image", "video", "audio"} {
				if strings.HasPrefix(part.MimeType, family+"/") {
					label = family + " attachment"
					break
				}
			}
			lines = append(lines, fmt.Sprintf("[%s: %s]", label, part.MimeType))
		}
		if len(lines) == 0 {
			continue
		}

		speaker := "User"
		if turn.Role == "model" {
			speaker = "Assistant"
		}
		sections = append(sections, fmt.Sprintf("[%s]\n%s", speaker, strings.Join(lines, "\n")))
	}
	return strings.Join(sections, "\n\n")
}
// buildCascadeConfig builds the cascade configuration that selects the
// requested model: {"plannerConfig": {"requestedModel": {"model": <enum>}}}.
// It returns nil when the model name is empty after normalization.
func buildCascadeConfig(model string) map[string]any {
	name := normalizeRequestedModelName(model)
	if name == "" {
		return nil
	}
	requested := map[string]any{"model": resolveModelEnum(name)}
	planner := map[string]any{"requestedModel": requested}
	return map[string]any{"plannerConfig": planner}
}
// buildLSRequestMetadata returns the fixed IDE identification metadata
// (name and version) attached to LS requests. A fresh map is returned on
// every call.
func buildLSRequestMetadata() map[string]any {
	metadata := make(map[string]any, 2)
	metadata["ideName"] = "antigravity"
	metadata["ideVersion"] = "1.107.0"
	return metadata
}
// appendModelTurn appends a "model" turn holding modelText as a single text
// part. Blank model text leaves turns unchanged; otherwise the text is stored
// as given (not trimmed).
func appendModelTurn(turns []geminiConversationTurn, modelText string) []geminiConversationTurn {
	if len(strings.TrimSpace(modelText)) == 0 {
		return turns
	}
	reply := geminiConversationTurn{Role: "model"}
	reply.Parts = []geminiConversationPart{{Kind: "text", Text: modelText}}
	return append(turns, reply)
}
// cloneConversationTurns returns a copy of src in which every turn owns its
// own Parts slice, so appending to or replacing parts of the copy does not
// affect the original. The result is always non-nil.
func cloneConversationTurns(src []geminiConversationTurn) []geminiConversationTurn {
	cloned := make([]geminiConversationTurn, len(src))
	for i := range src {
		cloned[i] = geminiConversationTurn{
			Role:  src[i].Role,
			Parts: append([]geminiConversationPart(nil), src[i].Parts...),
		}
	}
	return cloned
}
// conversationPrefixEqual reports whether prefix is a leading run of full:
// every turn of prefix must match the turn at the same index in full by role
// and by an identical sequence of parts. An empty prefix always matches.
func conversationPrefixEqual(full, prefix []geminiConversationTurn) bool {
	if len(full) < len(prefix) {
		return false
	}
	for i, want := range prefix {
		got := full[i]
		if want.Role != got.Role || len(want.Parts) != len(got.Parts) {
			return false
		}
		for j, part := range want.Parts {
			if part != got.Parts[j] {
				return false
			}
		}
	}
	return true
}
// ResolveModelEnumPublic exposes resolveModelEnum to tests outside the
// package; it returns exactly what resolveModelEnum returns for model.
func ResolveModelEnumPublic(model string) int {
	enum := resolveModelEnum(model)
	return enum
}
// resolveModelEnum maps a Gemini/Claude model name to its proto enum number.
//
// Resolution order:
//  1. Dynamic mapping (populated by RefreshModelMapping from the LS): exact
//     key, then normalized-label equality, then prefix match in either
//     direction.
//  2. Static table of known placeholders: exact name, then substring match
//     in either direction.
//  3. Family fallback ("claude", "gemini", "gpt"): dynamic entries first,
//     then the static table.
//  4. Any dynamic entry, and finally the static gemini-2.5-flash enum (312).
//
// The LS uses MODEL_PLACEHOLDER_Mn enum values (1000+n) that are dynamically
// assigned by the server — only these are guaranteed to work, so the dynamic
// family fallback is consulted before the static one.
//
// Go randomizes map iteration order. To keep resolution stable, whenever
// several dynamic entries qualify the one whose label sorts first is chosen,
// and the static table is an ordered slice. Empty strings never count as a
// prefix or substring match, so an empty model name falls through to step 4.
func resolveModelEnum(model string) int {
	model = normalizeRequestedModelName(model)

	// firstDynamic returns the enum of the matching dynamic entry whose label
	// sorts first. The caller must hold dynamicModelMapMu.
	firstDynamic := func(match func(label string) bool) (int, bool) {
		var (
			bestLabel string
			bestEnum  int
			found     bool
		)
		for label, enum := range dynamicModelMap {
			if !match(label) {
				continue
			}
			if !found || label < bestLabel {
				bestLabel, bestEnum, found = label, enum, true
			}
		}
		return bestEnum, found
	}

	// sameFamily reports whether candidate (lower-case) belongs to the same
	// vendor family as the requested model.
	sameFamily := func(candidate string) bool {
		for _, family := range []string{"claude", "gemini", "gpt"} {
			if strings.Contains(model, family) && strings.Contains(candidate, family) {
				return true
			}
		}
		return false
	}

	// 1. Try the dynamic mapping first.
	dynamicModelMapMu.RLock()
	enum, ok := dynamicModelMap[model]
	if !ok {
		// Fuzzy match: normalized label vs model name.
		enum, ok = firstDynamic(func(label string) bool {
			return labelMatchesModel(label, model)
		})
	}
	if !ok && model != "" {
		// Prefix match in either direction.
		enum, ok = firstDynamic(func(label string) bool {
			normalized := normalizeLabel(label)
			if normalized == "" {
				return false
			}
			return strings.HasPrefix(model, normalized) || strings.HasPrefix(normalized, model)
		})
	}
	dynamicModelMapMu.RUnlock()
	if ok {
		return enum
	}

	// 2. Known working placeholders (verified on Mac with LS v1.107.0).
	// These map display labels to MODEL_PLACEHOLDER_Mn enum values. The slice
	// order is the tie-break order for substring and family matches.
	knownPlaceholders := []struct {
		name string
		enum int
	}{
		{"gemini-3-flash", 1047},
		{"gemini-3.1-pro-high", 1037},
		{"gemini-3.1-pro-low", 1036},
		{"claude-sonnet-4-6-thinking", 1035},
		{"claude-opus-4-6-thinking", 1026},
		{"gpt-oss-120b-medium", 342},
	}
	for _, p := range knownPlaceholders {
		if p.name == model {
			return p.enum
		}
	}
	if model != "" {
		for _, p := range knownPlaceholders {
			if strings.Contains(model, p.name) || strings.Contains(p.name, model) {
				return p.enum
			}
		}
	}

	// 3. Family-based fallback: dynamic entries first, then the static table.
	dynamicModelMapMu.RLock()
	defer dynamicModelMapMu.RUnlock()
	if v, found := firstDynamic(func(label string) bool {
		return sameFamily(normalizeLabel(label))
	}); found {
		return v
	}
	for _, p := range knownPlaceholders {
		if sameFamily(p.name) {
			return p.enum
		}
	}

	// 4. Last resort: any model from the dynamic map.
	if v, found := firstDynamic(func(string) bool { return true }); found {
		return v
	}
	// No dynamic mapping at all (LS not started yet?) — use gemini-2.5-flash static.
	return 312
}
// labelMatchesModel does fuzzy matching between an LS display label and a
// sub2api model name, e.g. "Gemini 3 Flash" matches "gemini-3-flash" and
// "Claude Sonnet 4.6 (Thinking)" matches "claude-sonnet-4-6-thinking".
//
// Both sides are reduced to the same kebab-case form before comparison:
// lower-cased, spaces and dots turned into dashes, parentheses removed, runs
// of dashes collapsed to one, trailing dashes trimmed.
func labelMatchesModel(label, model string) bool {
	normalize := func(s string) string {
		s = strings.ToLower(s)
		s = strings.ReplaceAll(s, " ", "-")
		s = strings.ReplaceAll(s, ".", "-")
		s = strings.ReplaceAll(s, "(", "")
		s = strings.ReplaceAll(s, ")", "")
		// A single ReplaceAll pass only halves a run of dashes ("---" becomes
		// "--"), so repeat until no doubled dash is left.
		for strings.Contains(s, "--") {
			s = strings.ReplaceAll(s, "--", "-")
		}
		return strings.TrimRight(s, "-")
	}
	return normalize(label) == normalize(model)
}
// Dynamic model mapping — refreshed from LS at startup.
//
// RefreshModelMapping replaces dynamicModelMap wholesale while holding the
// write lock; readers take the read lock. Keys are both the LS display label
// and its normalized kebab-case form.
var (
	dynamicModelMapMu sync.RWMutex
	dynamicModelMap   = map[string]int{} // label -> enum value
)
// HasDynamicModelMappingPublic exposes hasDynamicModelMapping to tests
// outside the package.
func HasDynamicModelMappingPublic() bool {
	loaded := hasDynamicModelMapping()
	return loaded
}
// hasDynamicModelMapping reports whether the dynamic model map currently
// holds at least one entry. The map is read under the read lock.
func hasDynamicModelMapping() bool {
	dynamicModelMapMu.RLock()
	entries := len(dynamicModelMap)
	dynamicModelMapMu.RUnlock()
	return entries != 0
}
// RefreshModelMapping queries the LS for available models and builds the
// label -> enum mapping. Called automatically when an LS instance starts.
//
// On success the package-level dynamic map is replaced with the new mapping,
// the instance is marked as having its model mapping ready, and true is
// returned. On any failure — nil instance, RPC error, undecodable response,
// or a response without a single usable model — the existing mapping is left
// untouched, the instance (when non-nil) is marked not ready, a warning is
// logged, and false is returned.
func RefreshModelMapping(inst *Instance) bool {
	if inst == nil {
		return false
	}
	startedAt := time.Now()
	ctx, cancel := context.WithTimeout(context.Background(), lsModelConfigTimeout)
	defer cancel()

	// fail marks the mapping as not ready and logs why, so that every failure
	// path is visible in the logs rather than only the RPC error.
	fail := func(msg string, err error) bool {
		inst.SetModelMappingReady(false)
		attrs := []any{
			"account", inst.AccountID,
			"replica", inst.Replica,
			"address", inst.Address,
			"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
		}
		if err != nil {
			attrs = append(attrs, "err", err)
		}
		slog.Warn(msg, attrs...)
		return false
	}

	resp, err := inst.CallUnaryJSON(ctx, LSService, "GetCascadeModelConfigData", map[string]any{})
	if err != nil {
		return fail("[LS-POOL] Failed to get model config", err)
	}

	var data struct {
		ClientModelConfigs []struct {
			Label        string         `json:"label"`
			ModelOrAlias map[string]any `json:"modelOrAlias"`
		} `json:"clientModelConfigs"`
	}
	if err := json.Unmarshal(resp, &data); err != nil {
		return fail("[LS-POOL] Failed to decode model config", err)
	}

	newMap := make(map[string]int, 2*len(data.ClientModelConfigs))
	models := 0 // configs accepted; map size can differ because of the extra normalized keys
	for _, cfg := range data.ClientModelConfigs {
		if cfg.Label == "" {
			continue
		}
		// modelOrAlias is {"model": "MODEL_PLACEHOLDER_M37"} in JSON.
		modelStr, _ := cfg.ModelOrAlias["model"].(string)
		if modelStr == "" {
			continue
		}
		// Parse "MODEL_PLACEHOLDER_M37" → 1037.
		enumVal := parseModelEnumString(modelStr)
		if enumVal <= 0 {
			continue
		}
		models++
		// Store both the display label and its kebab-case form:
		// "Gemini 3 Flash" → "gemini-3-flash".
		newMap[cfg.Label] = enumVal
		if normalized := normalizeLabel(cfg.Label); normalized != "" {
			newMap[normalized] = enumVal
		}
	}
	if len(newMap) == 0 {
		return fail("[LS-POOL] Model config contained no usable models", nil)
	}

	dynamicModelMapMu.Lock()
	dynamicModelMap = newMap
	dynamicModelMapMu.Unlock()
	inst.SetModelMappingReady(true)
	slog.Info("[LS-POOL] Model mapping refreshed",
		"account", inst.AccountID,
		"replica", inst.Replica,
		"address", inst.Address,
		"count", models,
		"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
	return true
}
// parseModelEnumString converts a model enum name as reported by the LS into
// its numeric value.
//
// Known named enums map to their fixed numbers. "MODEL_PLACEHOLDER_M<n>"
// maps to 1000+n, where <n> must be a plain unsigned decimal number; signed
// or out-of-range suffixes are rejected. Anything unrecognized yields 0.
func parseModelEnumString(s string) int {
	// Named enums.
	switch s {
	case "MODEL_CLAUDE_4_SONNET":
		return 281
	case "MODEL_CLAUDE_4_SONNET_THINKING":
		return 282
	case "MODEL_CLAUDE_4_OPUS":
		return 290
	case "MODEL_CLAUDE_4_OPUS_THINKING":
		return 291
	case "MODEL_CLAUDE_4_5_SONNET":
		return 333
	case "MODEL_CLAUDE_4_5_SONNET_THINKING":
		return 334
	case "MODEL_CLAUDE_4_5_HAIKU":
		return 340
	case "MODEL_CLAUDE_4_5_HAIKU_THINKING":
		return 341
	case "MODEL_OPENAI_GPT_OSS_120B_MEDIUM":
		return 342
	case "MODEL_GOOGLE_GEMINI_2_5_FLASH":
		return 312
	case "MODEL_GOOGLE_GEMINI_2_5_FLASH_THINKING":
		return 313
	case "MODEL_GOOGLE_GEMINI_2_5_FLASH_LITE":
		return 330
	case "MODEL_GOOGLE_GEMINI_2_5_PRO":
		return 246
	}

	// "MODEL_PLACEHOLDER_M37" → 1037.
	const placeholderPrefix = "MODEL_PLACEHOLDER_M"
	if !strings.HasPrefix(s, placeholderPrefix) {
		return 0
	}
	// ParseUint rejects "+"/"-" signs, and the 30-bit limit keeps 1000+n
	// within int range on every platform.
	n, err := strconv.ParseUint(s[len(placeholderPrefix):], 10, 30)
	if err != nil {
		return 0
	}
	return 1000 + int(n)
}
// normalizeLabel converts an LS display label into kebab-case:
// lower-cased, spaces and dots turned into dashes, parentheses removed, runs
// of dashes collapsed to a single dash, and trailing dashes trimmed.
// For example "Claude Sonnet 4.6 (Thinking)" → "claude-sonnet-4-6-thinking".
func normalizeLabel(label string) string {
	s := strings.ToLower(label)
	s = strings.ReplaceAll(s, " ", "-")
	s = strings.ReplaceAll(s, ".", "-")
	s = strings.ReplaceAll(s, "(", "")
	s = strings.ReplaceAll(s, ")", "")
	// A single ReplaceAll pass only halves a run of dashes ("---" becomes
	// "--"), so repeat until no doubled dash is left.
	for strings.Contains(s, "--") {
		s = strings.ReplaceAll(s, "--", "-")
	}
	return strings.TrimRight(s, "-")
}
// normalizeRequestedModelName canonicalizes a client-supplied model name:
// surrounding whitespace is trimmed, the name is lower-cased, and a single
// leading "models/" segment is removed.
func normalizeRequestedModelName(model string) string {
	const resourcePrefix = "models/"
	name := strings.TrimSpace(model)
	name = strings.ToLower(name)
	if strings.HasPrefix(name, resourcePrefix) {
		return name[len(resourcePrefix):]
	}
	return name
}
// isGeminiPlannerModel reports whether the normalized model name contains
// "gemini".
func isGeminiPlannerModel(model string) bool {
	name := normalizeRequestedModelName(model)
	return strings.Contains(name, "gemini")
}
// systemTextCompatible reports whether the current system text is acceptable
// given the stored one: either the current text is blank, or both texts are
// equal once surrounding whitespace is ignored.
func systemTextCompatible(stored, current string) bool {
	want := strings.TrimSpace(current)
	if want == "" {
		return true
	}
	return want == strings.TrimSpace(stored)
}
// ============================================================
// Helpers
// ============================================================
// buildSessionCacheKey joins the account ID (in decimal), the namespace and
// the session ID with ":" to form the key used for session lookups.
func buildSessionCacheKey(accountID int64, namespace, sessionID string) string {
	account := strconv.FormatInt(accountID, 10)
	return account + ":" + namespace + ":" + sessionID
}
// userNamespace derives a short namespace identifying the caller of req.
//
// The first non-blank value among the namespace header, X-Api-Key,
// X-Goog-Api-Key and Authorization (in that order) is hashed with SHA-256
// and the first 8 bytes are returned as lower-case hex. The value is hashed
// as received, without trimming. A nil request, or one carrying none of
// these headers, yields "anon".
func userNamespace(req *http.Request) string {
	const anonymous = "anon"
	if req == nil {
		return anonymous
	}
	headerNames := [...]string{
		userNamespaceHeader,
		"X-Api-Key",
		"X-Goog-Api-Key",
		"Authorization",
	}
	for _, name := range headerNames {
		raw := req.Header.Get(name)
		if strings.TrimSpace(raw) == "" {
			continue
		}
		digest := sha256.Sum256([]byte(raw))
		return fmt.Sprintf("%x", digest[:8])
	}
	return anonymous
}
// getSessionState returns a copy of the session state stored under key, or
// nil when there is none. Expired sessions are pruned first. The copy has its
// own history slice, so the caller may modify it without touching the cache.
func (u *LSPoolUpstream) getSessionState(key string) *cascadeSessionState {
	u.sessionMu.Lock()
	defer u.sessionMu.Unlock()

	u.pruneExpiredSessionsLocked()
	stored, ok := u.sessions[key]
	if !ok || stored == nil {
		return nil
	}
	return &cascadeSessionState{
		CascadeID:  stored.CascadeID,
		SystemText: stored.SystemText,
		History:    cloneConversationTurns(stored.History),
		UpdatedAt:  stored.UpdatedAt,
	}
}
// putSessionState stores a copy of state under key, replacing any previous
// entry. A nil state is ignored. Expired sessions are pruned before the
// insert, and the stored copy owns its own history slice.
func (u *LSPoolUpstream) putSessionState(key string, state *cascadeSessionState) {
	if state == nil {
		return
	}
	u.sessionMu.Lock()
	defer u.sessionMu.Unlock()

	u.pruneExpiredSessionsLocked()
	stored := new(cascadeSessionState)
	stored.CascadeID = state.CascadeID
	stored.SystemText = state.SystemText
	stored.History = cloneConversationTurns(state.History)
	stored.UpdatedAt = state.UpdatedAt
	u.sessions[key] = stored
}
// pruneExpiredSessionsLocked removes nil entries and entries whose last
// update is older than sessionStateTTL. It does not lock; both callers in
// this file hold u.sessionMu when calling it.
func (u *LSPoolUpstream) pruneExpiredSessionsLocked() {
	now := time.Now()
	for key, state := range u.sessions {
		if state != nil && now.Sub(state.UpdatedAt) <= sessionStateTTL {
			continue
		}
		delete(u.sessions, key)
	}
}
// isStreamGenerate reports whether path refers to the streaming generate
// endpoint, i.e. contains "streamGenerateContent" (case-sensitive).
func isStreamGenerate(path string) bool {
	const marker = "streamGenerateContent"
	return strings.Contains(path, marker)
}
// isQuotaExhaustedError detects 429 QUOTA_EXHAUSTED errors from the LS
// cascade trajectory. When detected, the caller should fall back to direct
// HTTP so the gateway can inject enabledCreditTypes for AI Credits retry.
//
// The check is case-insensitive and needs both an exhaustion marker
// ("resource_exhausted" or "quota_exhausted") and either "429" or the phrase
// "exhausted your capacity".
func isQuotaExhaustedError(msg string) bool {
	text := strings.ToLower(msg)
	if !strings.Contains(text, "resource_exhausted") && !strings.Contains(text, "quota_exhausted") {
		return false
	}
	return strings.Contains(text, "429") || strings.Contains(text, "exhausted your capacity")
}
// isImageGenerationModelName reports whether the normalized model name is one
// of the image-generation model families, either exactly or with any
// "-suffix" appended (which includes the "-preview" variants).
func isImageGenerationModelName(model string) bool {
	name := normalizeRequestedModelName(model)
	families := [...]string{
		"gemini-3.1-flash-image",
		"gemini-3-pro-image",
		"gemini-2.5-flash-image",
	}
	for _, family := range families {
		if name == family || strings.HasPrefix(name, family+"-") {
			return true
		}
	}
	return false
}
// requestHasTools checks if the Gemini request body contains tools/function
// declarations. These are not supported through the Cascade path and must use
// direct HTTP.
//
// Both the wrapped format {"request": {"tools": [...]}} and the direct format
// {"tools": [...]} are inspected. A tools value counts when it is a JSON
// array with at least one element; null and empty arrays (however they are
// spaced) do not count. A body that is not a JSON object yields false.
func requestHasTools(body []byte) bool {
	// nonEmptyTools decides whether a raw "tools" value declares anything.
	nonEmptyTools := func(raw json.RawMessage) bool {
		var list []json.RawMessage
		if err := json.Unmarshal(raw, &list); err != nil {
			// Not an array: keep the legacy length heuristic so unexpected
			// shapes are classified as they were before.
			return len(raw) > 4
		}
		return len(list) > 0
	}

	var outer map[string]json.RawMessage
	if err := json.Unmarshal(body, &outer); err != nil {
		return false
	}

	// Check in the wrapped request.
	if reqRaw, ok := outer["request"]; ok {
		var inner map[string]json.RawMessage
		if json.Unmarshal(reqRaw, &inner) == nil {
			if tools, ok := inner["tools"]; ok && nonEmptyTools(tools) {
				return true
			}
		}
	}

	// Check at the top level.
	if tools, ok := outer["tools"]; ok && nonEmptyTools(tools) {
		return true
	}
	return false
}
// snapshotRequestBody reads the whole request body and returns it, replacing
// req.Body with a fresh reader over the same bytes so the request can still
// be sent afterwards.
//
// A nil request or a nil body yields (nil, nil). The original body is closed
// whether or not reading succeeds; on a read error the error is returned and
// req.Body is left pointing at the closed original.
func snapshotRequestBody(req *http.Request) ([]byte, error) {
	if req == nil || req.Body == nil {
		return nil, nil
	}
	body, err := io.ReadAll(req.Body)
	// The close error is deliberately ignored: the payload has already been
	// read (or the read error is what gets reported).
	_ = req.Body.Close()
	if err != nil {
		return nil, err
	}
	req.Body = io.NopCloser(bytes.NewReader(body))
	return body, nil
}
// Blank declaration that references the sync package. The package is already
// referenced by dynamicModelMapMu (sync.RWMutex) in this file, so this line
// is not what keeps the import in use.
var _ sync.Mutex