1646 lines
48 KiB
Go
1646 lines
48 KiB
Go
// Package lspool provides an HTTPUpstream adapter that routes
|
|
// streamGenerateContent requests through real Language Server instances.
|
|
//
|
|
// Flow:
|
|
//
|
|
// sub2api → LSPoolUpstream.Do() → StartCascade → SendUserCascadeMessage
|
|
// → LS internally calls cloudcode-pa (with authentic TLS fingerprint)
|
|
// → Poll GetCascadeTrajectory for incremental text
|
|
// → Format as SSE and stream back to sub2api service layer
|
|
//
|
|
// The model is extracted from the original request body, not hardcoded.
|
|
package lspool
|
|
|
|
import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
)
|
|
|
|
// Upstream is the interface matching service.HTTPUpstream
|
|
type Upstream interface {
|
|
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
|
|
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error)
|
|
}
|
|
|
|
// LSPoolUpstream wraps an existing HTTPUpstream and intercepts
|
|
// streamGenerateContent requests to route them through the LS pool.
|
|
type LSPoolUpstream struct {
|
|
pool Backend
|
|
fallback Upstream
|
|
logger *slog.Logger
|
|
sessionMu sync.Mutex
|
|
sessions map[string]*cascadeSessionState
|
|
}
|
|
|
|
// NewLSPoolUpstream creates an LS pool upstream wrapper.
|
|
func NewLSPoolUpstream(pool Backend, fallback Upstream) *LSPoolUpstream {
|
|
return &LSPoolUpstream{
|
|
pool: pool,
|
|
fallback: fallback,
|
|
logger: slog.Default().With("component", "lspool-upstream"),
|
|
sessions: make(map[string]*cascadeSessionState),
|
|
}
|
|
}
|
|
|
|
// Internal headers set by the caller. They are consumed by this package and
// deleted before a request is sent upstream.
const (
	userNamespaceHeader       = "X-Sub2API-User-Key"
	useAICreditsHeader        = "X-Antigravity-Use-AI-Credits"
	availableCreditsHeader    = "X-Antigravity-Available-Credits"
	minimumCreditAmountHeader = "X-Antigravity-Minimum-Credit-Amount"
)

// Timing limits for cached cascade sessions and LS RPCs.
const (
	sessionStateTTL      = 30 * time.Minute // lifetime of a cached cascade session
	lsSendMessageTimeout = 20 * time.Second // bound on SendUserCascadeMessage
	lsModelConfigTimeout = 20 * time.Second
)
|
|
|
|
// Sentinel errors used to steer routing inside this package.
var (
	// errLSRouteDirect: the request should be sent to the direct upstream.
	errLSRouteDirect = errors.New("request should use direct upstream")
	// errLSTranscriptDrift: the request no longer matches the cached cascade session.
	errLSTranscriptDrift = errors.New("request transcript diverged from cached cascade session")
	// errLSQuotaExhausted: wraps LS cascade errors recognized as quota/capacity exhaustion.
	errLSQuotaExhausted = errors.New("ls cascade returned quota exhausted")
	// errLSModelMapPending: the LS instance's model mapping is not ready yet.
	errLSModelMapPending = errors.New("model mapping not ready")
)

// IsLSQuotaExhaustedError reports whether err, or any error it wraps, is the
// LS cascade quota/capacity exhaustion signal.
func IsLSQuotaExhaustedError(err error) bool {
	return errors.Is(err, errLSQuotaExhausted)
}

// LSQuotaExhaustedMessage extracts the original LS error message carried by
// a quota-exhaustion error.
//
// Quota errors are produced as "<sentinel>: <LS message>". Callers may wrap
// them further (e.g. "read stream: <sentinel>: <LS message>"), so the
// sentinel is located anywhere in the text, not only at its start.
//
// Returns "" for a nil error, an empty message, or a sentinel with no LS
// message after it. When the sentinel does not appear at all, the trimmed
// error text is returned unchanged.
func LSQuotaExhaustedMessage(err error) string {
	if err == nil {
		return ""
	}
	msg := strings.TrimSpace(err.Error())
	sentinel := errLSQuotaExhausted.Error()
	if _, after, found := strings.Cut(msg, sentinel+":"); found {
		return strings.TrimSpace(after)
	}
	if strings.HasSuffix(msg, sentinel) {
		// Bare sentinel (possibly wrapped) with no LS message attached.
		return ""
	}
	return msg
}
|
|
|
|
// Request, session and tracing data types used by the LS routing path.
type (
	// cascadeSessionState is the cached state of one cascade conversation:
	// the cascade ID, the system text it was started with, the turns it has
	// seen (including the model's final reply), and the last update time.
	cascadeSessionState struct {
		CascadeID, SystemText string
		History               []geminiConversationTurn
		UpdatedAt             time.Time
	}

	// geminiEnvelope is the outer v1internal wrapper {"model": ..., "request": {...}}.
	geminiEnvelope struct {
		Model   string          `json:"model"`
		Request json.RawMessage `json:"request"`
	}

	// geminiRequestPayload is the subset of a Gemini request body inspected here.
	geminiRequestPayload struct {
		Contents          []geminiWireContent         `json:"contents"`
		SystemInstruction *geminiWireContent          `json:"systemInstruction,omitempty"`
		GenerationConfig  *geminiWireGenerationConfig `json:"generationConfig,omitempty"`
		SessionID         string                      `json:"sessionId,omitempty"`
	}

	// geminiWireGenerationConfig keeps only the fields that affect LS routing.
	// ImageConfig is kept raw; only its presence matters.
	geminiWireGenerationConfig struct {
		ResponseModalities []string        `json:"responseModalities,omitempty"`
		ImageConfig        json.RawMessage `json:"imageConfig,omitempty"`
	}

	// geminiWireContent is one role-tagged message of a Gemini request.
	geminiWireContent struct {
		Role  string           `json:"role"`
		Parts []geminiWirePart `json:"parts"`
	}

	// geminiWirePart is a single Gemini content part as sent on the wire.
	geminiWirePart struct {
		Text             string                `json:"text,omitempty"`
		Thought          bool                  `json:"thought,omitempty"`
		ThoughtSignature string                `json:"thoughtSignature,omitempty"`
		InlineData       *geminiWireInlineData `json:"inlineData,omitempty"`
		FunctionCall     map[string]any        `json:"functionCall,omitempty"`
		FunctionResponse map[string]any        `json:"functionResponse,omitempty"`
	}

	// geminiWireInlineData is an inline media blob (base64 data plus MIME type).
	geminiWireInlineData struct {
		MimeType string `json:"mimeType"`
		Data     string `json:"data"`
	}

	// geminiParsedRequest is the flattened view of a Gemini request used to
	// make routing decisions and to build LS input.
	geminiParsedRequest struct {
		Model, SessionID, SystemText string
		Turns                        []geminiConversationTurn
		ResponseModalities           []string
		// HasImageConfig: a non-null imageConfig was supplied.
		// HasUnsupported: some part cannot be represented in a cascade.
		HasImageConfig, HasUnsupported bool
	}

	// geminiConversationTurn is one normalized turn ("user" or "model").
	geminiConversationTurn struct {
		Role  string
		Parts []geminiConversationPart
	}

	// geminiConversationPart is either Kind "text" (Text set) or Kind
	// "media" (MimeType and Data set).
	geminiConversationPart struct {
		Kind, Text, MimeType, Data string
	}

	// lsRouteDecision records whether a request goes through LS and why.
	lsRouteDecision struct {
		UseLS  bool
		Reason string
	}

	// lsRequestTrace collects per-request timing and routing facts that are
	// logged by logTraceSummary.
	lsRequestTrace struct {
		StartedAt         time.Time
		AccountID         int64
		Model             string
		SessionIDHash     string
		Replica           int
		CascadeID         string
		NewSession        bool
		InflightAtAcquire int64
		TurnCount         int

		// Durations of the setup phases on the request path.
		GetOrCreateDuration, StartCascadeDuration, BuildInputDuration, SendMessageDuration time.Duration

		// Elapsed time from StartedAt to the first successful poll and the first emitted text.
		FirstPollLatency, FirstTextLatency time.Duration
		PollCount                          int
	}
)
|
|
|
|
// Do routes streamGenerateContent through LS, everything else through fallback.
|
|
func (u *LSPoolUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
|
u.extractAndStripInternalHeaders(req, strconv.FormatInt(accountID, 10))
|
|
|
|
if !isStreamGenerate(req.URL.Path) {
|
|
return u.fallback.Do(req, proxyURL, accountID, accountConcurrency)
|
|
}
|
|
body, err := snapshotRequestBody(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("snapshot request body: %w", err)
|
|
}
|
|
if len(bytes.TrimSpace(body)) == 0 {
|
|
return u.fallback.Do(req, proxyURL, accountID, accountConcurrency)
|
|
}
|
|
|
|
resp, err := u.doViaLS(req, body, accountID, proxyURL)
|
|
if err != nil {
|
|
if shouldFallbackDirect(err) {
|
|
u.logger.Warn("[LS-POOL] LS fell back to direct", "account", accountID, "err", err)
|
|
req.Body = io.NopCloser(bytes.NewReader(body))
|
|
return u.fallback.Do(req, proxyURL, accountID, accountConcurrency)
|
|
}
|
|
return nil, err
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
// DoWithTLS — LS handles its own TLS, so profile is ignored for LS requests.
|
|
func (u *LSPoolUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
|
|
u.extractAndStripInternalHeaders(req, strconv.FormatInt(accountID, 10))
|
|
|
|
if !isStreamGenerate(req.URL.Path) {
|
|
return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile)
|
|
}
|
|
body, err := snapshotRequestBody(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("snapshot request body: %w", err)
|
|
}
|
|
if len(bytes.TrimSpace(body)) == 0 {
|
|
return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile)
|
|
}
|
|
|
|
resp, err := u.doViaLS(req, body, accountID, proxyURL)
|
|
if err != nil {
|
|
if shouldFallbackDirect(err) {
|
|
u.logger.Warn("[LS-POOL] LS fell back to direct+TLS", "account", accountID, "err", err)
|
|
req.Body = io.NopCloser(bytes.NewReader(body))
|
|
return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile)
|
|
}
|
|
return nil, err
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func (u *LSPoolUpstream) doViaLS(req *http.Request, body []byte, accountID int64, proxyURL string) (*http.Response, error) {
|
|
accountKey := strconv.FormatInt(accountID, 10)
|
|
|
|
if CurrentLSStrategy() != LSStrategyJSParity {
|
|
return u.forwardDirectWithKeepalive(req, body, accountKey, accountID, proxyURL)
|
|
}
|
|
|
|
parsed, err := parseGeminiRequest(body)
|
|
if err != nil {
|
|
return u.forwardDirect(req, body, proxyURL, accountID, "parse request failed")
|
|
}
|
|
|
|
decision := decideJSParityRoute(parsed, body)
|
|
if !decision.UseLS {
|
|
return u.forwardDirect(req, body, proxyURL, accountID, decision.Reason)
|
|
}
|
|
|
|
resp, err := u.forwardChatViaLS(req, body, parsed, accountKey, accountID, proxyURL)
|
|
if err != nil {
|
|
if shouldFallbackDirect(err) {
|
|
return u.forwardDirect(req, body, proxyURL, accountID, err.Error())
|
|
}
|
|
return nil, err
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func shouldFallbackDirect(err error) bool {
|
|
return errors.Is(err, errLSRouteDirect) || errors.Is(err, errLSTranscriptDrift)
|
|
}
|
|
|
|
func (u *LSPoolUpstream) forwardDirectWithKeepalive(req *http.Request, body []byte, accountKey string, accountID int64, proxyURL string) (*http.Response, error) {
|
|
// Start/reuse LS instance — keeps heartbeat alive, authenticates with
|
|
// cloudcode-pa, and refreshes model mapping. The LS process itself is NOT
|
|
// used as a proxy; we forward the original HTTP request directly to
|
|
// cloudcode-pa, bypassing Cascade entirely. This avoids the IDE agent
|
|
// system prompt that Cascade injects.
|
|
_, err := u.pool.GetOrCreate(accountKey, "", proxyURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get LS instance: %w", err)
|
|
}
|
|
|
|
u.logger.Info("[LS-POOL] Forwarding via direct HTTP (LS keepalive active)",
|
|
"account", accountID, "path", req.URL.Path)
|
|
|
|
return u.forwardDirect(req, body, proxyURL, accountID, "strategy=direct")
|
|
}
|
|
|
|
func (u *LSPoolUpstream) forwardDirect(req *http.Request, body []byte, proxyURL string, accountID int64, reason string) (*http.Response, error) {
|
|
u.logger.Info("[LS-POOL] Forwarding via direct HTTP",
|
|
"account", accountID,
|
|
"path", req.URL.Path,
|
|
"reason", reason)
|
|
req.Header.Del(userNamespaceHeader)
|
|
req.Body = io.NopCloser(bytes.NewReader(body))
|
|
return u.fallback.Do(req, proxyURL, accountID, 1)
|
|
}
|
|
|
|
func (u *LSPoolUpstream) extractAndStripInternalHeaders(req *http.Request, accountKey string) {
|
|
if auth := req.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
|
|
accessToken := strings.TrimPrefix(auth, "Bearer ")
|
|
refreshToken := req.Header.Get("X-Antigravity-Refresh-Token")
|
|
var expiresAt time.Time
|
|
if raw := req.Header.Get("X-Antigravity-Token-Expiry"); raw != "" {
|
|
if parsed, err := time.Parse(time.RFC3339, raw); err == nil {
|
|
expiresAt = parsed
|
|
}
|
|
}
|
|
u.pool.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt)
|
|
}
|
|
|
|
useAICredits, hasUseAICredits := parseBoolHeader(req.Header.Get(useAICreditsHeader))
|
|
availableCredits, hasAvailableCredits := parseOptionalInt32Header(req.Header.Get(availableCreditsHeader))
|
|
minimumCreditAmount, hasMinimumCreditAmount := parseOptionalInt32Header(req.Header.Get(minimumCreditAmountHeader))
|
|
if hasUseAICredits || hasAvailableCredits || hasMinimumCreditAmount {
|
|
u.pool.SetAccountModelCredits(accountKey, useAICredits, availableCredits, minimumCreditAmount)
|
|
}
|
|
|
|
req.Header.Del("X-Antigravity-Refresh-Token")
|
|
req.Header.Del("X-Antigravity-Token-Expiry")
|
|
req.Header.Del(useAICreditsHeader)
|
|
req.Header.Del(availableCreditsHeader)
|
|
req.Header.Del(minimumCreditAmountHeader)
|
|
}
|
|
|
|
// parseBoolHeader parses a header value with strconv.ParseBool syntax,
// ignoring surrounding whitespace. The second result reports whether a valid
// value was present. Empty or malformed input yields (false, false).
func parseBoolHeader(raw string) (bool, bool) {
	// ParseBool rejects the empty string, so no separate emptiness check is needed.
	value, err := strconv.ParseBool(strings.TrimSpace(raw))
	if err != nil {
		return false, false
	}
	return value, true
}
|
|
|
|
// parseOptionalInt32Header parses a base-10 int32 header value, ignoring
// surrounding whitespace. It returns a pointer to the value and true, or
// (nil, false) when the value is empty, malformed or out of int32 range.
func parseOptionalInt32Header(raw string) (*int32, bool) {
	// ParseInt rejects "" and enforces the 32-bit range via bitSize.
	n, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 32)
	if err != nil {
		return nil, false
	}
	value := int32(n)
	return &value, true
}
|
|
|
|
// shortTraceID returns a short, non-reversible tag for an identifier in
// logs: the first 4 bytes of the SHA-256 of the trimmed value as 8 lowercase
// hex digits. A blank input yields "none".
func shortTraceID(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if len(trimmed) == 0 {
		return "none"
	}
	digest := sha256.Sum256([]byte(trimmed))
	return fmt.Sprintf("%02x%02x%02x%02x", digest[0], digest[1], digest[2], digest[3])
}
|
|
|
|
// durationMS converts d to whole milliseconds for logging, clamping
// non-positive durations (including unset ones) to 0.
func durationMS(d time.Duration) int64 {
	if d > 0 {
		return d.Milliseconds()
	}
	return 0
}
|
|
|
|
func (u *LSPoolUpstream) logTraceSummary(level slog.Level, msg string, trace *lsRequestTrace, extra ...any) {
|
|
if trace == nil {
|
|
u.logger.Log(context.Background(), level, msg, extra...)
|
|
return
|
|
}
|
|
args := []any{
|
|
"account", trace.AccountID,
|
|
"model", trace.Model,
|
|
"session", trace.SessionIDHash,
|
|
"replica", trace.Replica,
|
|
"cascade", shortTraceID(trace.CascadeID),
|
|
"new_session", trace.NewSession,
|
|
"turns", trace.TurnCount,
|
|
"inflight", trace.InflightAtAcquire,
|
|
"get_or_create_ms", durationMS(trace.GetOrCreateDuration),
|
|
"start_cascade_ms", durationMS(trace.StartCascadeDuration),
|
|
"build_input_ms", durationMS(trace.BuildInputDuration),
|
|
"send_message_ms", durationMS(trace.SendMessageDuration),
|
|
"first_poll_ms", durationMS(trace.FirstPollLatency),
|
|
"first_token_ms", durationMS(trace.FirstTextLatency),
|
|
"polls", trace.PollCount,
|
|
"total_ms", durationMS(time.Since(trace.StartedAt)),
|
|
}
|
|
args = append(args, extra...)
|
|
u.logger.Log(context.Background(), level, msg, args...)
|
|
}
|
|
|
|
// forwardChatViaLS serves one Gemini streamGenerateContent request through a
// cascade on an LS instance. It returns a synthetic 200 text/event-stream
// response whose body is fed by a goroutine running streamCascadeResponse.
//
// Session handling is keyed by buildSessionCacheKey(account, user namespace,
// sessionId):
//   - No cached state: start a new cascade and send the last turn, which must
//     be a user turn. Earlier turns are rendered into a text context prefix.
//   - Cached state whose history passes conversationPrefixEqual: send only
//     the single new user turn to the cached cascade.
//   - Otherwise return errLSTranscriptDrift or errLSRouteDirect; doViaLS
//     then forwards the request directly.
//
// The instance concurrency slot is released before returning on every error
// path after AcquireConcurrency. On success, ownership of the slot passes to
// the streaming goroutine, which releases it when the stream ends. When the
// stream completes normally, the session cache is updated with the request
// turns plus the model's final text.
//
// body is not referenced in this function; the parsed form is used instead.
func (u *LSPoolUpstream) forwardChatViaLS(req *http.Request, body []byte, parsed *geminiParsedRequest, accountKey string, accountID int64, proxyURL string) (*http.Response, error) {
	trace := &lsRequestTrace{
		StartedAt:     time.Now(),
		AccountID:     accountID,
		Model:         parsed.Model,
		SessionIDHash: shortTraceID(parsed.SessionID),
		TurnCount:     len(parsed.Turns),
	}

	getOrCreateStartedAt := time.Now()
	sessionKey := buildSessionCacheKey(accountID, userNamespace(req), parsed.SessionID)
	inst, err := u.pool.GetOrCreate(accountKey, sessionKey, proxyURL)
	if err != nil {
		trace.GetOrCreateDuration = time.Since(getOrCreateStartedAt)
		u.logTraceSummary(slog.LevelWarn, "[LS-POOL] get instance failed", trace, "err", err)
		return nil, fmt.Errorf("get LS instance: %w", err)
	}
	trace.GetOrCreateDuration = time.Since(getOrCreateStartedAt)
	trace.Replica = inst.Replica
	// Checked before AcquireConcurrency, so there is no slot to release here.
	if !inst.HasModelMappingReady() {
		u.logTraceSummary(slog.LevelInfo, "[LS-POOL] model mapping pending, routing direct", trace)
		return nil, errLSModelMapPending
	}
	if !inst.AcquireConcurrency() {
		u.logTraceSummary(slog.LevelWarn, "[LS-POOL] instance busy", trace,
			"err", fmt.Sprintf("ls instance busy for account %d", accountID),
			"current_inflight", inst.ConcurrentCount(),
			"max_inflight", maxConcurrencyPerInstance)
		return nil, fmt.Errorf("ls instance busy for account %d", accountID)
	}
	// From here on, every early return must call inst.ReleaseConcurrency().
	trace.InflightAtAcquire = inst.ConcurrentCount()

	// A cached session whose system text is incompatible cannot be continued.
	state := u.getSessionState(sessionKey)
	if state != nil && !systemTextCompatible(state.SystemText, parsed.SystemText) {
		inst.ReleaseConcurrency()
		u.logTraceSummary(slog.LevelInfo, "[LS-POOL] transcript drift, routing direct", trace)
		return nil, errLSTranscriptDrift
	}

	cascadeID := ""
	newSession := false
	sendTurn := geminiConversationTurn{}
	contextPrefix := ""

	switch {
	case state == nil:
		// New session: send the final user turn and fold all earlier turns
		// (plus the system text) into a text prefix.
		if len(parsed.Turns) == 0 {
			inst.ReleaseConcurrency()
			return nil, errLSRouteDirect
		}
		lastTurn := parsed.Turns[len(parsed.Turns)-1]
		if lastTurn.Role != "user" {
			inst.ReleaseConcurrency()
			u.logTraceSummary(slog.LevelInfo, "[LS-POOL] invalid first turn for LS, routing direct", trace)
			return nil, errLSRouteDirect
		}
		sendTurn = lastTurn
		contextPrefix = renderConversationContext(parsed.SystemText, parsed.Turns[:len(parsed.Turns)-1])
		startCascadeStartedAt := time.Now()
		cascadeID, err = u.startCascade(inst)
		trace.StartCascadeDuration = time.Since(startCascadeStartedAt)
		if err != nil {
			inst.ReleaseConcurrency()
			u.logTraceSummary(slog.LevelWarn, "[LS-POOL] start cascade failed", trace, "err", err)
			return nil, err
		}
		newSession = true
	case !conversationPrefixEqual(parsed.Turns, state.History):
		inst.ReleaseConcurrency()
		u.logTraceSummary(slog.LevelInfo, "[LS-POOL] transcript prefix mismatch, routing direct", trace)
		return nil, errLSTranscriptDrift
	default:
		// Continuing a cached session: exactly one new user turn is supported.
		// NOTE(review): this slice relies on conversationPrefixEqual guaranteeing
		// len(parsed.Turns) >= len(state.History) — confirm in that helper.
		delta := parsed.Turns[len(state.History):]
		if len(delta) != 1 || delta[0].Role != "user" {
			inst.ReleaseConcurrency()
			u.logTraceSummary(slog.LevelInfo, "[LS-POOL] unsupported transcript delta, routing direct", trace)
			return nil, errLSRouteDirect
		}
		sendTurn = delta[0]
		cascadeID = state.CascadeID
	}
	trace.NewSession = newSession
	trace.CascadeID = cascadeID

	buildInputStartedAt := time.Now()
	items, media, err := buildLSInputFromTurn(sendTurn, contextPrefix)
	trace.BuildInputDuration = time.Since(buildInputStartedAt)
	if err != nil {
		inst.ReleaseConcurrency()
		u.logTraceSummary(slog.LevelWarn, "[LS-POOL] build input failed", trace, "err", err)
		return nil, fmt.Errorf("build ls input: %w", err)
	}
	if len(items) == 0 && len(media) == 0 {
		inst.ReleaseConcurrency()
		u.logTraceSummary(slog.LevelInfo, "[LS-POOL] empty LS input, routing direct", trace)
		return nil, errLSRouteDirect
	}
	// Non-blocking send: the reply is collected by polling the trajectory.
	sendReq := map[string]any{
		"metadata":  buildLSRequestMetadata(),
		"cascadeId": cascadeID,
		"items":     items,
		"blocking":  false,
	}
	if len(media) > 0 {
		sendReq["media"] = media
	}
	if cfg := buildCascadeConfig(parsed.Model); cfg != nil {
		sendReq["cascadeConfig"] = cfg
	}
	sendStartedAt := time.Now()
	// sendCtx only bounds SendUserCascadeMessage; the deferred cancel runs
	// when this function returns, which does not affect the stream goroutine.
	sendCtx, sendCancel := context.WithTimeout(req.Context(), lsSendMessageTimeout)
	defer sendCancel()
	if _, err := inst.CallUnaryJSON(sendCtx, LSService, "SendUserCascadeMessage", sendReq); err != nil {
		trace.SendMessageDuration = time.Since(sendStartedAt)
		// Only cascades created by this request are canceled; a cached
		// session's cascade is left alone.
		if newSession {
			u.cancelCascade(inst, cascadeID)
		}
		inst.ReleaseConcurrency()
		u.logTraceSummary(slog.LevelWarn, "[LS-POOL] send user message failed", trace, "err", err)
		return nil, fmt.Errorf("send user cascade message: %w", err)
	}
	trace.SendMessageDuration = time.Since(sendStartedAt)

	// The response body is the read side of a pipe that the goroutine below writes SSE into.
	pr, pw := io.Pipe()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header: http.Header{
			"Content-Type":      []string{"text/event-stream"},
			"Cache-Control":     []string{"no-cache"},
			"X-Accel-Buffering": []string{"no"},
		},
		Body:    pr,
		Request: req,
	}

	// The goroutine now owns the concurrency slot and the trace.
	go func() {
		defer inst.ReleaseConcurrency()
		ctx, cancel := context.WithCancel(req.Context())
		defer cancel()

		// onDone runs only on normal completion and caches the exchange so the
		// next request in this session can reuse the cascade.
		u.streamCascadeResponse(ctx, inst, cascadeID, pw, trace, func(finalText string) {
			u.putSessionState(sessionKey, &cascadeSessionState{
				CascadeID:  cascadeID,
				SystemText: parsed.SystemText,
				History:    appendModelTurn(cloneConversationTurns(parsed.Turns), finalText),
				UpdatedAt:  time.Now(),
			})
		})
	}()

	return resp, nil
}
|
|
|
|
func (u *LSPoolUpstream) startCascade(inst *Instance) (string, error) {
|
|
resp, err := inst.CallUnaryJSON(context.Background(), LSService, "StartCascade", map[string]any{
|
|
"metadata": buildLSRequestMetadata(),
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("start cascade: %w", err)
|
|
}
|
|
var decoded struct {
|
|
CascadeID string `json:"cascadeId"`
|
|
}
|
|
if err := json.Unmarshal(resp, &decoded); err != nil {
|
|
return "", fmt.Errorf("decode start cascade: %w", err)
|
|
}
|
|
if decoded.CascadeID == "" {
|
|
return "", errors.New("start cascade returned empty cascadeId")
|
|
}
|
|
return decoded.CascadeID, nil
|
|
}
|
|
|
|
// cancelCascade tells the LS to stop processing a cascade invocation.
|
|
// Uses a short timeout — best-effort, don't block shutdown.
|
|
func (u *LSPoolUpstream) cancelCascade(inst *Instance, cascadeID string) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
_, err := inst.CallUnaryJSON(ctx, LSService, "CancelCascadeInvocation", map[string]any{
|
|
"cascadeId": cascadeID,
|
|
})
|
|
if err != nil {
|
|
// Try force stop as fallback
|
|
_, _ = inst.CallUnaryJSON(ctx, LSService, "ForceStopCascadeTree", map[string]any{
|
|
"cascadeId": cascadeID,
|
|
})
|
|
}
|
|
}
|
|
|
|
// streamCascadeResponse polls GetCascadeTrajectory with adaptive interval.
|
|
// Fast (50ms) when model is generating, slow (150ms) when waiting.
|
|
// We also issue an immediate first poll so the first token is not delayed by
|
|
// the initial ticker interval.
|
|
func (u *LSPoolUpstream) streamCascadeResponse(ctx context.Context, inst *Instance, cascadeID string, w *io.PipeWriter, trace *lsRequestTrace, onDone func(string)) {
|
|
const (
|
|
fastInterval = 50 * time.Millisecond
|
|
slowInterval = 150 * time.Millisecond
|
|
maxDuration = 5 * time.Minute
|
|
maxIdleTimeout = 30 * time.Second
|
|
)
|
|
|
|
ticker := time.NewTicker(slowInterval)
|
|
defer ticker.Stop()
|
|
|
|
timeout := time.After(maxDuration)
|
|
lastText := ""
|
|
generating := false
|
|
lastProgressAt := time.Time{}
|
|
|
|
pollOnce := func() bool {
|
|
if trace != nil {
|
|
trace.PollCount++
|
|
}
|
|
trajResp, err := inst.CallUnaryJSON(ctx, LSService, "GetCascadeTrajectory", map[string]any{
|
|
"cascadeId": cascadeID,
|
|
})
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request context canceled", trace)
|
|
_ = w.Close()
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
if trace != nil && trace.FirstPollLatency == 0 {
|
|
trace.FirstPollLatency = time.Since(trace.StartedAt)
|
|
}
|
|
|
|
state := extractPlannerResponseState(trajResp)
|
|
text, isGenerating, status := state.Text, state.Generating, state.Status
|
|
if state.ErrorMessage != "" {
|
|
u.logTraceSummary(slog.LevelWarn, "[LS-POOL] Cascade terminated with error", trace, "error", state.ErrorMessage)
|
|
if isQuotaExhaustedError(state.ErrorMessage) {
|
|
_ = w.CloseWithError(fmt.Errorf("%w: %s", errLSQuotaExhausted, state.ErrorMessage))
|
|
} else {
|
|
_ = w.CloseWithError(errors.New(state.ErrorMessage))
|
|
}
|
|
return true
|
|
}
|
|
|
|
// Adaptive interval: fast when generating, slow when idle.
|
|
if isGenerating && !generating {
|
|
ticker.Reset(fastInterval)
|
|
generating = true
|
|
} else if !isGenerating && generating {
|
|
ticker.Reset(slowInterval)
|
|
generating = false
|
|
}
|
|
|
|
// Emit new text as SSE.
|
|
if text != lastText && len(text) > len(lastText) {
|
|
newPart := text[len(lastText):]
|
|
sseEvent := buildGeminiSSEChunk(newPart)
|
|
if _, err := w.Write([]byte(sseEvent)); err != nil {
|
|
u.logTraceSummary(slog.LevelWarn, "[LS-POOL] write SSE failed", trace, "err", err)
|
|
_ = w.CloseWithError(err)
|
|
return true
|
|
}
|
|
lastText = text
|
|
lastProgressAt = time.Now()
|
|
if trace != nil && trace.FirstTextLatency == 0 {
|
|
trace.FirstTextLatency = time.Since(trace.StartedAt)
|
|
}
|
|
}
|
|
|
|
// Check if done.
|
|
if status == "CASCADE_RUN_STATUS_IDLE" && text != "" && !isGenerating {
|
|
usage := extractUsageFromTrajectory(trajResp)
|
|
if usage != nil {
|
|
finalEvent := buildGeminiSSEFinalChunk(usage)
|
|
if _, err := w.Write([]byte(finalEvent)); err != nil {
|
|
u.logTraceSummary(slog.LevelWarn, "[LS-POOL] write final SSE failed", trace, "err", err)
|
|
_ = w.CloseWithError(err)
|
|
return true
|
|
}
|
|
}
|
|
if onDone != nil {
|
|
onDone(lastText)
|
|
}
|
|
u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request completed", trace)
|
|
_ = w.Close()
|
|
return true
|
|
}
|
|
|
|
if !lastProgressAt.IsZero() && time.Since(lastProgressAt) > maxIdleTimeout {
|
|
u.logTraceSummary(slog.LevelWarn, "[LS-POOL] No progress, stopping", trace)
|
|
_ = w.Close()
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
if pollOnce() {
|
|
return
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request context canceled", trace)
|
|
_ = w.Close()
|
|
return
|
|
case <-timeout:
|
|
u.logTraceSummary(slog.LevelWarn, "[LS-POOL] Cascade timeout", trace)
|
|
_ = w.Close()
|
|
return
|
|
case <-ticker.C:
|
|
if pollOnce() {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ============================================================
|
|
// SSE builders — match Gemini v1internal:streamGenerateContent?alt=sse format
|
|
// ============================================================
|
|
|
|
// buildGeminiSSEChunk wraps text as one SSE "data:" event in the cloudcode-pa
// v1internal shape {"response": {"candidates": [...]}}, with a single
// model-role candidate holding one text part.
func buildGeminiSSEChunk(text string) string {
	content := map[string]any{
		"role":  "model",
		"parts": []map[string]string{{"text": text}},
	}
	candidate := map[string]any{"content": content}
	envelope := map[string]any{
		"response": map[string]any{"candidates": []map[string]any{candidate}},
	}
	// Marshaling string-only maps cannot fail.
	payload, _ := json.Marshal(envelope)

	var sb strings.Builder
	sb.WriteString("data: ")
	sb.Write(payload)
	sb.WriteString("\n\n")
	return sb.String()
}
|
|
|
|
// buildGeminiSSEFinalChunk builds the terminal SSE event. It carries an
// empty text part, finishReason "STOP" and the given usage map as
// usageMetadata, in the same v1internal shape as buildGeminiSSEChunk.
func buildGeminiSSEFinalChunk(usage map[string]any) string {
	candidate := map[string]any{
		"finishReason": "STOP",
		"content": map[string]any{
			"role":  "model",
			"parts": []map[string]string{{"text": ""}},
		},
	}
	response := map[string]any{
		"usageMetadata": usage,
		"candidates":    []map[string]any{candidate},
	}
	encoded, _ := json.Marshal(map[string]any{"response": response})
	return "data: " + string(encoded) + "\n\n"
}
|
|
|
|
// ============================================================
|
|
// Trajectory parsing
|
|
// ============================================================
|
|
|
|
type cascadePlannerState struct {
|
|
Text string
|
|
Generating bool
|
|
Status string
|
|
ErrorMessage string
|
|
}
|
|
|
|
func extractPlannerResponseState(trajResp []byte) cascadePlannerState {
|
|
var raw map[string]any
|
|
if err := json.Unmarshal(trajResp, &raw); err != nil {
|
|
return cascadePlannerState{}
|
|
}
|
|
state := cascadePlannerState{}
|
|
state.Status, _ = raw["status"].(string)
|
|
state.ErrorMessage = findCascadeErrorMessage(raw)
|
|
|
|
traj, ok := raw["trajectory"].(map[string]any)
|
|
if !ok {
|
|
return state
|
|
}
|
|
steps, ok := traj["steps"].([]any)
|
|
if !ok {
|
|
return state
|
|
}
|
|
for _, s := range steps {
|
|
sm, ok := s.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if sm["type"] != "CORTEX_STEP_TYPE_PLANNER_RESPONSE" {
|
|
continue
|
|
}
|
|
if sm["status"] == "CORTEX_STEP_STATUS_GENERATING" {
|
|
state.Generating = true
|
|
}
|
|
if pr, ok := sm["plannerResponse"].(map[string]any); ok {
|
|
if r, ok := pr["response"].(string); ok {
|
|
state.Text = r
|
|
}
|
|
}
|
|
}
|
|
return state
|
|
}
|
|
|
|
func extractPlannerResponseText(trajResp []byte) (text string, generating bool, status string) {
|
|
state := extractPlannerResponseState(trajResp)
|
|
return state.Text, state.Generating, state.Status
|
|
}
|
|
|
|
// findCascadeErrorMessage walks a decoded JSON value depth-first. It returns
// the first error description produced by summarizeCascadeErrorMap, or ""
// when there is none.
//
// Map children are visited in sorted key order. Go randomizes map iteration,
// so without sorting the reported message would be nondeterministic whenever
// several nested objects carry errors.
func findCascadeErrorMessage(value any) string {
	switch v := value.(type) {
	case map[string]any:
		if msg := summarizeCascadeErrorMap(v); msg != "" {
			return msg
		}
		keys := make([]string, 0, len(v))
		for k := range v {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		for _, k := range keys {
			if msg := findCascadeErrorMessage(v[k]); msg != "" {
				return msg
			}
		}
	case []any:
		for _, child := range v {
			if msg := findCascadeErrorMessage(child); msg != "" {
				return msg
			}
		}
	}
	return ""
}

// summarizeCascadeErrorMap returns an error description if m itself looks
// like an error object, or "" otherwise. Checks run in this priority order:
//   - a non-empty "fullError", then a non-empty "userErrorMessage";
//   - when an "errorCode" key is present: the distinct non-empty values of
//     "shortError", "message" and "details", joined with ": ". If none are
//     present, "cascade error: <errorCode>";
//   - when "terminationReason" contains "ERROR" (case-insensitive):
//     "message", then "shortError".
//
// String fields are whitespace-trimmed; non-string values are ignored.
func summarizeCascadeErrorMap(m map[string]any) string {
	if full := cascadeStringField(m, "fullError"); full != "" {
		return full
	}
	if user := cascadeStringField(m, "userErrorMessage"); user != "" {
		return user
	}

	_, hasErrorCode := m["errorCode"]
	short := cascadeStringField(m, "shortError")
	details := cascadeStringField(m, "details")
	message := cascadeStringField(m, "message")

	if hasErrorCode {
		parts := make([]string, 0, 3)
		if short != "" {
			parts = append(parts, short)
		}
		if message != "" && message != short {
			parts = append(parts, message)
		}
		if details != "" && details != short && details != message {
			parts = append(parts, details)
		}
		if len(parts) > 0 {
			return strings.Join(parts, ": ")
		}
		return fmt.Sprintf("cascade error: %v", m["errorCode"])
	}

	if reason := cascadeStringField(m, "terminationReason"); strings.Contains(strings.ToUpper(reason), "ERROR") {
		if message != "" {
			return message
		}
		if short != "" {
			return short
		}
	}

	return ""
}

// cascadeStringField returns m[key] trimmed of surrounding whitespace, or ""
// when the key is absent or not a string.
func cascadeStringField(m map[string]any, key string) string {
	s, _ := m[key].(string)
	return strings.TrimSpace(s)
}
|
|
|
|
// extractUsageFromTrajectory returns Gemini-style usageMetadata
// (promptTokenCount, candidatesTokenCount, totalTokenCount). It reads the
// most recent CORTEX_STEP_TYPE_PLANNER_RESPONSE step whose
// metadata.modelUsage reports non-zero tokens. It returns nil when trajResp
// is not valid JSON or no such step exists.
//
// Steps are scanned newest-first. This matches extractPlannerResponseState,
// which takes its text from the last planner step. A cascade reused across
// turns keeps earlier planner steps, so the first step's usage could describe
// an earlier turn. Token counts are expected as decimal strings; JSON numbers
// are accepted as well.
func extractUsageFromTrajectory(trajResp []byte) map[string]any {
	var raw map[string]any
	if err := json.Unmarshal(trajResp, &raw); err != nil {
		return nil
	}
	traj, _ := raw["trajectory"].(map[string]any)
	steps, _ := traj["steps"].([]any)

	for i := len(steps) - 1; i >= 0; i-- {
		step, ok := steps[i].(map[string]any)
		if !ok || step["type"] != "CORTEX_STEP_TYPE_PLANNER_RESPONSE" {
			continue
		}
		meta, _ := step["metadata"].(map[string]any)
		usage, ok := meta["modelUsage"].(map[string]any)
		if !ok {
			continue
		}
		input := trajectoryTokenCount(usage["inputTokens"])
		output := trajectoryTokenCount(usage["outputTokens"])
		if input > 0 || output > 0 {
			return map[string]any{
				"promptTokenCount":     input,
				"candidatesTokenCount": output,
				"totalTokenCount":      input + output,
			}
		}
	}
	return nil
}

// trajectoryTokenCount converts a token count given as a decimal string or a
// JSON number into an int. Any other value, including a malformed string,
// yields 0 (Atoi's error result is intentionally ignored).
func trajectoryTokenCount(v any) int {
	switch n := v.(type) {
	case string:
		parsed, _ := strconv.Atoi(n)
		return parsed
	case float64:
		return int(n)
	default:
		return 0
	}
}
|
|
|
|
// ============================================================
|
|
// Request parsing — dynamic model, no hardcoding
|
|
// ============================================================
|
|
|
|
func parseGeminiRequest(body []byte) (*geminiParsedRequest, error) {
|
|
var envelope geminiEnvelope
|
|
if err := json.Unmarshal(body, &envelope); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
reqBody := body
|
|
if len(envelope.Request) > 0 {
|
|
reqBody = envelope.Request
|
|
}
|
|
|
|
var payload geminiRequestPayload
|
|
if err := json.Unmarshal(reqBody, &payload); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
parsed := &geminiParsedRequest{
|
|
Model: envelope.Model,
|
|
SessionID: payload.SessionID,
|
|
ResponseModalities: append([]string(nil), payload.GenerationConfig.GetResponseModalities()...),
|
|
HasImageConfig: payload.GenerationConfig != nil && len(bytes.TrimSpace(payload.GenerationConfig.ImageConfig)) > 0 && string(bytes.TrimSpace(payload.GenerationConfig.ImageConfig)) != "null",
|
|
}
|
|
if parsed.Model == "" {
|
|
var top map[string]json.RawMessage
|
|
if err := json.Unmarshal(body, &top); err == nil {
|
|
_ = json.Unmarshal(top["model"], &parsed.Model)
|
|
}
|
|
}
|
|
if payload.SystemInstruction != nil {
|
|
parsed.SystemText = collectTextParts(payload.SystemInstruction.Parts)
|
|
}
|
|
for _, content := range payload.Contents {
|
|
turn := geminiConversationTurn{Role: normalizeTurnRole(content.Role)}
|
|
for _, part := range content.Parts {
|
|
switch {
|
|
case part.Thought || part.ThoughtSignature != "":
|
|
parsed.HasUnsupported = true
|
|
case len(part.FunctionCall) > 0 || len(part.FunctionResponse) > 0:
|
|
parsed.HasUnsupported = true
|
|
case part.InlineData != nil:
|
|
turn.Parts = append(turn.Parts, geminiConversationPart{
|
|
Kind: "media",
|
|
MimeType: part.InlineData.MimeType,
|
|
Data: part.InlineData.Data,
|
|
})
|
|
case part.Text != "":
|
|
turn.Parts = append(turn.Parts, geminiConversationPart{
|
|
Kind: "text",
|
|
Text: part.Text,
|
|
})
|
|
}
|
|
}
|
|
if len(turn.Parts) > 0 {
|
|
parsed.Turns = append(parsed.Turns, turn)
|
|
}
|
|
}
|
|
|
|
return parsed, nil
|
|
}
|
|
|
|
func collectTextParts(parts []geminiWirePart) string {
|
|
var texts []string
|
|
for _, part := range parts {
|
|
if part.Text != "" {
|
|
texts = append(texts, part.Text)
|
|
}
|
|
}
|
|
return strings.Join(texts, "\n")
|
|
}
|
|
|
|
func (g *geminiWireGenerationConfig) GetResponseModalities() []string {
|
|
if g == nil {
|
|
return nil
|
|
}
|
|
return g.ResponseModalities
|
|
}
|
|
|
|
// normalizeTurnRole maps a Gemini content role onto the two roles the cascade
// conversation model knows. "model" (case-insensitive, surrounding spaces
// ignored) stays "model"; every other value, including "", becomes "user".
func normalizeTurnRole(role string) string {
	switch {
	case strings.EqualFold(strings.TrimSpace(role), "model"):
		return "model"
	default:
		return "user"
	}
}
|
|
|
|
func decideJSParityRoute(parsed *geminiParsedRequest, body []byte) lsRouteDecision {
|
|
if parsed == nil {
|
|
return lsRouteDecision{Reason: "nil parsed request"}
|
|
}
|
|
if requestHasTools(body) {
|
|
return lsRouteDecision{Reason: "tools are not supported through cascade"}
|
|
}
|
|
if parsed.SessionID == "" {
|
|
return lsRouteDecision{Reason: "missing sessionId"}
|
|
}
|
|
if parsed.HasUnsupported {
|
|
return lsRouteDecision{Reason: "request contains unsupported Gemini parts"}
|
|
}
|
|
if isImageGenerationModelName(parsed.Model) {
|
|
return lsRouteDecision{Reason: "image generation model"}
|
|
}
|
|
if parsed.HasImageConfig {
|
|
return lsRouteDecision{Reason: "request has imageConfig"}
|
|
}
|
|
for _, modality := range parsed.ResponseModalities {
|
|
if strings.EqualFold(strings.TrimSpace(modality), "IMAGE") {
|
|
return lsRouteDecision{Reason: "responseModalities contains IMAGE"}
|
|
}
|
|
}
|
|
if len(parsed.Turns) == 0 {
|
|
return lsRouteDecision{Reason: "empty conversation"}
|
|
}
|
|
return lsRouteDecision{UseLS: true, Reason: "js-parity cascade chat"}
|
|
}
|
|
|
|
func extractPromptAndModel(body []byte) (string, string) {
|
|
var outer map[string]json.RawMessage
|
|
if err := json.Unmarshal(body, &outer); err != nil {
|
|
return "", ""
|
|
}
|
|
var model string
|
|
if m, ok := outer["model"]; ok {
|
|
json.Unmarshal(m, &model)
|
|
}
|
|
if reqRaw, ok := outer["request"]; ok {
|
|
return extractPromptFromGeminiRequest(reqRaw), model
|
|
}
|
|
return extractPromptFromGeminiRequest(body), model
|
|
}
|
|
|
|
// extractPromptFromGeminiRequest flattens a Gemini generateContent payload
// into a single plain-text prompt.
//
// Every non-empty system-instruction part becomes "[System]\n<text>". Every
// non-empty content part becomes "[Assistant]\n<text>" for the "model" role
// (matched case-insensitively, like normalizeTurnRole) and "[User]\n<text>"
// otherwise. The pieces are joined with a blank line. When exactly one piece
// results and it did not come from the system instruction, a "[User]\n"
// label is stripped so a simple single message is passed through verbatim.
// Invalid JSON or a payload without any text yields "".
func extractPromptFromGeminiRequest(data []byte) string {
	var req struct {
		Contents []struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
			Role string `json:"role"`
		} `json:"contents"`
		SystemInstruction *struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
		} `json:"systemInstruction"`
	}
	if err := json.Unmarshal(data, &req); err != nil {
		return ""
	}

	const userLabel = "[User]\n"
	var parts []string
	// Track whether system text was actually emitted. A systemInstruction
	// object whose parts are all empty must not disable the single-message
	// fast path below.
	hasSystemText := false

	if req.SystemInstruction != nil {
		for _, p := range req.SystemInstruction.Parts {
			if p.Text != "" {
				parts = append(parts, "[System]\n"+p.Text)
				hasSystemText = true
			}
		}
	}

	// Include the full conversation history.
	for _, c := range req.Contents {
		label := userLabel
		if strings.EqualFold(strings.TrimSpace(c.Role), "model") {
			label = "[Assistant]\n"
		}
		for _, p := range c.Parts {
			if p.Text != "" {
				parts = append(parts, label+p.Text)
			}
		}
	}

	switch {
	case len(parts) == 0:
		return ""
	case len(parts) == 1 && !hasSystemText:
		// Simple single-message case: hand over the raw user text. TrimPrefix
		// is a no-op for a lone "[Assistant]" piece, which keeps its label.
		return strings.TrimPrefix(parts[0], userLabel)
	}
	return strings.Join(parts, "\n\n")
}
|
|
|
|
func buildLSInputFromTurn(turn geminiConversationTurn, contextPrefix string) ([]map[string]any, []map[string]any, error) {
|
|
items := make([]map[string]any, 0, len(turn.Parts)+1)
|
|
media := make([]map[string]any, 0)
|
|
if strings.TrimSpace(contextPrefix) != "" {
|
|
items = append(items, map[string]any{"text": contextPrefix})
|
|
}
|
|
for _, part := range turn.Parts {
|
|
switch part.Kind {
|
|
case "text":
|
|
if part.Text != "" {
|
|
items = append(items, map[string]any{"text": part.Text})
|
|
}
|
|
case "media":
|
|
decoded, err := base64.StdEncoding.DecodeString(part.Data)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("decode inlineData: %w", err)
|
|
}
|
|
media = append(media, map[string]any{
|
|
"mimeType": part.MimeType,
|
|
"inlineData": decoded,
|
|
})
|
|
}
|
|
}
|
|
return items, media, nil
|
|
}
|
|
|
|
func renderConversationContext(systemText string, turns []geminiConversationTurn) string {
|
|
var parts []string
|
|
if strings.TrimSpace(systemText) != "" {
|
|
parts = append(parts, "[System]\n"+strings.TrimSpace(systemText))
|
|
}
|
|
for _, turn := range turns {
|
|
var rendered []string
|
|
for _, part := range turn.Parts {
|
|
switch part.Kind {
|
|
case "text":
|
|
if strings.TrimSpace(part.Text) != "" {
|
|
rendered = append(rendered, part.Text)
|
|
}
|
|
case "media":
|
|
label := "attachment"
|
|
switch {
|
|
case strings.HasPrefix(part.MimeType, "image/"):
|
|
label = "image attachment"
|
|
case strings.HasPrefix(part.MimeType, "video/"):
|
|
label = "video attachment"
|
|
case strings.HasPrefix(part.MimeType, "audio/"):
|
|
label = "audio attachment"
|
|
}
|
|
rendered = append(rendered, fmt.Sprintf("[%s: %s]", label, part.MimeType))
|
|
}
|
|
}
|
|
if len(rendered) == 0 {
|
|
continue
|
|
}
|
|
roleLabel := "User"
|
|
if turn.Role == "model" {
|
|
roleLabel = "Assistant"
|
|
}
|
|
parts = append(parts, fmt.Sprintf("[%s]\n%s", roleLabel, strings.Join(rendered, "\n")))
|
|
}
|
|
return strings.Join(parts, "\n\n")
|
|
}
|
|
|
|
func buildCascadeConfig(model string) map[string]any {
|
|
normalizedModel := normalizeRequestedModelName(model)
|
|
if normalizedModel == "" {
|
|
return nil
|
|
}
|
|
modelEnum := resolveModelEnum(normalizedModel)
|
|
|
|
return map[string]any{
|
|
"plannerConfig": map[string]any{
|
|
"requestedModel": map[string]any{
|
|
"model": modelEnum,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// buildLSRequestMetadata returns a fresh metadata map identifying the caller
// to the LS as the "antigravity" IDE at version 1.107.0. A new map is built on
// every call, so callers may mutate it freely.
func buildLSRequestMetadata() map[string]any {
	meta := make(map[string]any, 2)
	meta["ideName"] = "antigravity"
	meta["ideVersion"] = "1.107.0"
	return meta
}
|
|
|
|
func appendModelTurn(turns []geminiConversationTurn, modelText string) []geminiConversationTurn {
|
|
if strings.TrimSpace(modelText) == "" {
|
|
return turns
|
|
}
|
|
return append(turns, geminiConversationTurn{
|
|
Role: "model",
|
|
Parts: []geminiConversationPart{{
|
|
Kind: "text",
|
|
Text: modelText,
|
|
}},
|
|
})
|
|
}
|
|
|
|
func cloneConversationTurns(src []geminiConversationTurn) []geminiConversationTurn {
|
|
out := make([]geminiConversationTurn, 0, len(src))
|
|
for _, turn := range src {
|
|
copied := geminiConversationTurn{
|
|
Role: turn.Role,
|
|
Parts: append([]geminiConversationPart(nil), turn.Parts...),
|
|
}
|
|
out = append(out, copied)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func conversationPrefixEqual(full, prefix []geminiConversationTurn) bool {
|
|
if len(prefix) > len(full) {
|
|
return false
|
|
}
|
|
for i := range prefix {
|
|
if prefix[i].Role != full[i].Role {
|
|
return false
|
|
}
|
|
if len(prefix[i].Parts) != len(full[i].Parts) {
|
|
return false
|
|
}
|
|
for j := range prefix[i].Parts {
|
|
if prefix[i].Parts[j] != full[i].Parts[j] {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// ResolveModelEnumPublic is the exported version of resolveModelEnum for testing.
|
|
func ResolveModelEnumPublic(model string) int {
|
|
return resolveModelEnum(model)
|
|
}
|
|
|
|
// resolveModelEnum maps a Gemini/Claude model name to its proto enum number.
|
|
// Priority: dynamic mapping (from LS) > static fallback.
|
|
// The LS uses MODEL_PLACEHOLDER_Mn enum values (1000+n) that are dynamically
|
|
// assigned by the server — only these are guaranteed to work.
|
|
func resolveModelEnum(model string) int {
|
|
model = normalizeRequestedModelName(model)
|
|
|
|
// 1. Try dynamic mapping first (populated by RefreshModelMapping from LS)
|
|
dynamicModelMapMu.RLock()
|
|
// Exact match
|
|
if v, ok := dynamicModelMap[model]; ok {
|
|
dynamicModelMapMu.RUnlock()
|
|
return v
|
|
}
|
|
// Fuzzy match: normalized label vs model name
|
|
for label, v := range dynamicModelMap {
|
|
if labelMatchesModel(label, model) {
|
|
dynamicModelMapMu.RUnlock()
|
|
return v
|
|
}
|
|
}
|
|
// Prefix match in dynamic map
|
|
for label, v := range dynamicModelMap {
|
|
normalized := normalizeLabel(label)
|
|
if strings.HasPrefix(model, normalized) || strings.HasPrefix(normalized, model) {
|
|
dynamicModelMapMu.RUnlock()
|
|
return v
|
|
}
|
|
}
|
|
dynamicModelMapMu.RUnlock()
|
|
|
|
// 2. Known working placeholders (verified on Mac with LS v1.107.0)
|
|
// These map display labels to MODEL_PLACEHOLDER_Mn enum values
|
|
knownPlaceholders := map[string]int{
|
|
"gemini-3-flash": 1047,
|
|
"gemini-3.1-pro-high": 1037,
|
|
"gemini-3.1-pro-low": 1036,
|
|
"claude-sonnet-4-6-thinking": 1035,
|
|
"claude-opus-4-6-thinking": 1026,
|
|
"gpt-oss-120b-medium": 342,
|
|
}
|
|
if v, ok := knownPlaceholders[model]; ok {
|
|
return v
|
|
}
|
|
// Fuzzy match known placeholders
|
|
modelLower := strings.ToLower(model)
|
|
for k, v := range knownPlaceholders {
|
|
if strings.Contains(modelLower, strings.ToLower(k)) || strings.Contains(strings.ToLower(k), modelLower) {
|
|
return v
|
|
}
|
|
}
|
|
|
|
// 3. Family-based fallback from known placeholders
|
|
for k, v := range knownPlaceholders {
|
|
if strings.Contains(modelLower, "claude") && strings.Contains(k, "claude") {
|
|
return v
|
|
}
|
|
if strings.Contains(modelLower, "gemini") && strings.Contains(k, "gemini") {
|
|
return v
|
|
}
|
|
if strings.Contains(modelLower, "gpt") && strings.Contains(k, "gpt") {
|
|
return v
|
|
}
|
|
}
|
|
|
|
// 4. Also check dynamic map if available
|
|
dynamicModelMapMu.RLock()
|
|
defer dynamicModelMapMu.RUnlock()
|
|
for label, v := range dynamicModelMap {
|
|
labelLower := strings.ToLower(normalizeLabel(label))
|
|
// Same family: "claude" matches "claude-*", "gemini" matches "gemini-*"
|
|
if strings.Contains(modelLower, "claude") && strings.Contains(labelLower, "claude") {
|
|
return v
|
|
}
|
|
if strings.Contains(modelLower, "gemini") && strings.Contains(labelLower, "gemini") {
|
|
return v
|
|
}
|
|
if strings.Contains(modelLower, "gpt") && strings.Contains(labelLower, "gpt") {
|
|
return v
|
|
}
|
|
}
|
|
|
|
// Last resort: return first available model from dynamic map
|
|
for _, v := range dynamicModelMap {
|
|
return v
|
|
}
|
|
|
|
// No dynamic mapping at all (LS not started yet?) — use gemini-2.5-flash static
|
|
return 312
|
|
}
|
|
|
|
// labelMatchesModel does fuzzy matching between LS display label and sub2api model name.
|
|
// e.g. "Gemini 3 Flash" matches "gemini-3-flash", "Claude Sonnet 4.6 (Thinking)" matches "claude-sonnet-4-6-thinking"
|
|
func labelMatchesModel(label, model string) bool {
|
|
normalize := func(s string) string {
|
|
s = strings.ToLower(s)
|
|
s = strings.ReplaceAll(s, " ", "-")
|
|
s = strings.ReplaceAll(s, ".", "-")
|
|
s = strings.ReplaceAll(s, "(", "")
|
|
s = strings.ReplaceAll(s, ")", "")
|
|
s = strings.ReplaceAll(s, "--", "-")
|
|
return strings.TrimRight(s, "-")
|
|
}
|
|
return normalize(label) == normalize(model)
|
|
}
|
|
|
|
// dynamicModelMapMu guards dynamicModelMap.
var dynamicModelMapMu sync.RWMutex

// dynamicModelMap maps model labels reported by the LS (the raw display label
// and its normalizeLabel form) to proto enum values. RefreshModelMapping
// replaces it wholesale under dynamicModelMapMu; readers take the read lock.
var dynamicModelMap = map[string]int{}
|
|
|
|
// HasDynamicModelMappingPublic is exported for testing.
|
|
func HasDynamicModelMappingPublic() bool {
|
|
return hasDynamicModelMapping()
|
|
}
|
|
|
|
// hasDynamicModelMapping returns true if at least one model has been loaded from the LS.
|
|
func hasDynamicModelMapping() bool {
|
|
dynamicModelMapMu.RLock()
|
|
defer dynamicModelMapMu.RUnlock()
|
|
return len(dynamicModelMap) > 0
|
|
}
|
|
|
|
// RefreshModelMapping queries the LS for available models and builds the mapping.
|
|
// Called automatically when an LS instance starts.
|
|
func RefreshModelMapping(inst *Instance) bool {
|
|
if inst == nil {
|
|
return false
|
|
}
|
|
startedAt := time.Now()
|
|
ctx, cancel := context.WithTimeout(context.Background(), lsModelConfigTimeout)
|
|
defer cancel()
|
|
|
|
resp, err := inst.CallUnaryJSON(ctx, LSService, "GetCascadeModelConfigData", map[string]any{})
|
|
if err != nil {
|
|
inst.SetModelMappingReady(false)
|
|
slog.Warn("[LS-POOL] Failed to get model config",
|
|
"account", inst.AccountID,
|
|
"replica", inst.Replica,
|
|
"address", inst.Address,
|
|
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
|
|
"err", err)
|
|
return false
|
|
}
|
|
|
|
var data struct {
|
|
ClientModelConfigs []struct {
|
|
Label string `json:"label"`
|
|
ModelOrAlias map[string]any `json:"modelOrAlias"`
|
|
} `json:"clientModelConfigs"`
|
|
}
|
|
if err := json.Unmarshal(resp, &data); err != nil {
|
|
inst.SetModelMappingReady(false)
|
|
return false
|
|
}
|
|
|
|
newMap := make(map[string]int)
|
|
for _, cfg := range data.ClientModelConfigs {
|
|
label := cfg.Label
|
|
if label == "" {
|
|
continue
|
|
}
|
|
// modelOrAlias is {"model": "MODEL_PLACEHOLDER_M37"} in JSON
|
|
modelStr, _ := cfg.ModelOrAlias["model"].(string)
|
|
if modelStr == "" {
|
|
continue
|
|
}
|
|
// Parse "MODEL_PLACEHOLDER_M37" → 1037
|
|
enumVal := parseModelEnumString(modelStr)
|
|
if enumVal > 0 {
|
|
// Store both the display label and a normalized form
|
|
newMap[label] = enumVal
|
|
// Also store kebab-case version: "Gemini 3 Flash" → "gemini-3-flash"
|
|
normalized := normalizeLabel(label)
|
|
if normalized != "" {
|
|
newMap[normalized] = enumVal
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(newMap) > 0 {
|
|
dynamicModelMapMu.Lock()
|
|
dynamicModelMap = newMap
|
|
dynamicModelMapMu.Unlock()
|
|
inst.SetModelMappingReady(true)
|
|
slog.Info("[LS-POOL] Model mapping refreshed",
|
|
"account", inst.AccountID,
|
|
"replica", inst.Replica,
|
|
"address", inst.Address,
|
|
"count", len(newMap)/2,
|
|
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
|
|
return true
|
|
}
|
|
inst.SetModelMappingReady(false)
|
|
return false
|
|
}
|
|
|
|
// parseModelEnumString converts an LS model enum name into its proto enum
// number, returning 0 when the name is not recognised.
//
// Two forms are understood:
//   - a fixed set of named enums, e.g. "MODEL_GOOGLE_GEMINI_2_5_FLASH" → 312;
//   - dynamically assigned placeholders "MODEL_PLACEHOLDER_M<n>" → 1000+n,
//     where <n> must consist of plain decimal digits.
func parseModelEnumString(s string) int {
	// A switch avoids rebuilding a lookup map on every call.
	switch s {
	case "MODEL_CLAUDE_4_SONNET":
		return 281
	case "MODEL_CLAUDE_4_SONNET_THINKING":
		return 282
	case "MODEL_CLAUDE_4_OPUS":
		return 290
	case "MODEL_CLAUDE_4_OPUS_THINKING":
		return 291
	case "MODEL_CLAUDE_4_5_SONNET":
		return 333
	case "MODEL_CLAUDE_4_5_SONNET_THINKING":
		return 334
	case "MODEL_CLAUDE_4_5_HAIKU":
		return 340
	case "MODEL_CLAUDE_4_5_HAIKU_THINKING":
		return 341
	case "MODEL_OPENAI_GPT_OSS_120B_MEDIUM":
		return 342
	case "MODEL_GOOGLE_GEMINI_2_5_FLASH":
		return 312
	case "MODEL_GOOGLE_GEMINI_2_5_FLASH_THINKING":
		return 313
	case "MODEL_GOOGLE_GEMINI_2_5_FLASH_LITE":
		return 330
	case "MODEL_GOOGLE_GEMINI_2_5_PRO":
		return 246
	}

	// "MODEL_PLACEHOLDER_M37" → 1037
	const placeholderPrefix = "MODEL_PLACEHOLDER_M"
	if !strings.HasPrefix(s, placeholderPrefix) {
		return 0
	}
	digits := s[len(placeholderPrefix):]
	// strconv.Atoi also accepts a leading sign, which would turn
	// "MODEL_PLACEHOLDER_M-5" into the bogus enum 995, so require bare digits.
	if digits == "" || strings.TrimLeft(digits, "0123456789") != "" {
		return 0
	}
	n, err := strconv.Atoi(digits)
	if err != nil { // out of range for int
		return 0
	}
	return 1000 + n
}
|
|
|
|
// normalizeLabel converts an LS display label into kebab-case so it can be
// compared with requested model names. For example, "Gemini 3 Flash" becomes
// "gemini-3-flash" and "Claude Sonnet 4.6 (Thinking)" becomes
// "claude-sonnet-4-6-thinking".
//
// The label is lower-cased, and spaces, dots and dashes act as separators.
// Parentheses are dropped. Any run of separators collapses to a single "-",
// and separators at either end are removed. The previous single "--" → "-"
// pass left "--" behind for runs of three or more, e.g. "Model - Beta".
func normalizeLabel(label string) string {
	var b strings.Builder
	b.Grow(len(label))
	// A dash is emitted lazily, only when another content rune follows. This
	// collapses separator runs and trims both ends in one pass.
	pendingDash := false
	for _, r := range strings.ToLower(label) {
		switch r {
		case '(', ')':
			// Dropped entirely; they neither add content nor break a separator run.
		case ' ', '.', '-':
			pendingDash = b.Len() > 0
		default:
			if pendingDash {
				b.WriteByte('-')
				pendingDash = false
			}
			b.WriteRune(r)
		}
	}
	return b.String()
}
|
|
|
|
// normalizeRequestedModelName canonicalises a requested model name: it trims
// surrounding whitespace, lower-cases it and removes one leading "models/"
// resource prefix.
func normalizeRequestedModelName(model string) string {
	const resourcePrefix = "models/"
	lowered := strings.ToLower(strings.TrimSpace(model))
	if strings.HasPrefix(lowered, resourcePrefix) {
		return lowered[len(resourcePrefix):]
	}
	return lowered
}
|
|
|
|
func isGeminiPlannerModel(model string) bool {
|
|
return strings.Contains(normalizeRequestedModelName(model), "gemini")
|
|
}
|
|
|
|
// systemTextCompatible reports whether a request's system text can continue a
// session created with the stored system text. It is compatible when the
// current text is blank, or when both texts are equal ignoring surrounding
// whitespace.
func systemTextCompatible(stored, current string) bool {
	want := strings.TrimSpace(current)
	if want == "" {
		return true
	}
	return want == strings.TrimSpace(stored)
}
|
|
|
|
// ============================================================
|
|
// Helpers
|
|
// ============================================================
|
|
|
|
// buildSessionCacheKey builds the cascade session cache key
// "<accountID>:<namespace>:<sessionID>", so sessions are isolated per
// account and per caller namespace.
func buildSessionCacheKey(accountID int64, namespace, sessionID string) string {
	return strconv.FormatInt(accountID, 10) + ":" + namespace + ":" + sessionID
}
|
|
|
|
func userNamespace(req *http.Request) string {
|
|
if req == nil {
|
|
return "anon"
|
|
}
|
|
for _, value := range []string{
|
|
req.Header.Get(userNamespaceHeader),
|
|
req.Header.Get("X-Api-Key"),
|
|
req.Header.Get("X-Goog-Api-Key"),
|
|
req.Header.Get("Authorization"),
|
|
} {
|
|
if strings.TrimSpace(value) != "" {
|
|
sum := sha256.Sum256([]byte(value))
|
|
return fmt.Sprintf("%x", sum[:8])
|
|
}
|
|
}
|
|
return "anon"
|
|
}
|
|
|
|
func (u *LSPoolUpstream) getSessionState(key string) *cascadeSessionState {
|
|
u.sessionMu.Lock()
|
|
defer u.sessionMu.Unlock()
|
|
u.pruneExpiredSessionsLocked()
|
|
state := u.sessions[key]
|
|
if state == nil {
|
|
return nil
|
|
}
|
|
cloned := &cascadeSessionState{
|
|
CascadeID: state.CascadeID,
|
|
SystemText: state.SystemText,
|
|
History: cloneConversationTurns(state.History),
|
|
UpdatedAt: state.UpdatedAt,
|
|
}
|
|
return cloned
|
|
}
|
|
|
|
func (u *LSPoolUpstream) putSessionState(key string, state *cascadeSessionState) {
|
|
if state == nil {
|
|
return
|
|
}
|
|
u.sessionMu.Lock()
|
|
defer u.sessionMu.Unlock()
|
|
u.pruneExpiredSessionsLocked()
|
|
u.sessions[key] = &cascadeSessionState{
|
|
CascadeID: state.CascadeID,
|
|
SystemText: state.SystemText,
|
|
History: cloneConversationTurns(state.History),
|
|
UpdatedAt: state.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
func (u *LSPoolUpstream) pruneExpiredSessionsLocked() {
|
|
now := time.Now()
|
|
for key, state := range u.sessions {
|
|
if state == nil || now.Sub(state.UpdatedAt) > sessionStateTTL {
|
|
delete(u.sessions, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
// isStreamGenerate reports whether the request path targets the
// streamGenerateContent method, the only call routed through the LS pool.
func isStreamGenerate(path string) bool {
	_, _, found := strings.Cut(path, "streamGenerateContent")
	return found
}
|
|
|
|
// isQuotaExhaustedError detects 429 quota-exhaustion errors reported in an LS
// cascade trajectory. The message (compared case-insensitively) must mention
// "resource_exhausted" or "quota_exhausted", and also "429" or "exhausted
// your capacity". When this matches, the caller should fall back to direct
// HTTP so the gateway can inject enabledCreditTypes for an AI Credits retry.
func isQuotaExhaustedError(msg string) bool {
	lower := strings.ToLower(msg)
	exhausted := strings.Contains(lower, "resource_exhausted") ||
		strings.Contains(lower, "quota_exhausted")
	if !exhausted {
		return false
	}
	return strings.Contains(lower, "429") ||
		strings.Contains(lower, "exhausted your capacity")
}
|
|
|
|
func isImageGenerationModelName(model string) bool {
|
|
modelLower := normalizeRequestedModelName(model)
|
|
return modelLower == "gemini-3.1-flash-image" ||
|
|
modelLower == "gemini-3.1-flash-image-preview" ||
|
|
strings.HasPrefix(modelLower, "gemini-3.1-flash-image-") ||
|
|
modelLower == "gemini-3-pro-image" ||
|
|
modelLower == "gemini-3-pro-image-preview" ||
|
|
strings.HasPrefix(modelLower, "gemini-3-pro-image-") ||
|
|
modelLower == "gemini-2.5-flash-image" ||
|
|
modelLower == "gemini-2.5-flash-image-preview" ||
|
|
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
|
|
}
|
|
|
|
// requestHasTools reports whether a Gemini request body declares any tools
// (function declarations etc.). Tools are not supported through the Cascade
// path, so such requests must use direct HTTP.
//
// Both the wrapped form {"request": {"tools": [...]}} and the bare form
// {"tools": [...]} are checked. A "tools" value counts as present when it is a
// non-empty JSON array. null, a missing key or an empty array (including one
// written with inner whitespace, such as "[\n  ]") do not count. A non-null
// value that is not an array is treated as present, so the direct upstream
// gets to reject the malformed request. Bodies that are not JSON objects
// report false.
func requestHasTools(body []byte) bool {
	// The previous len(raw) > 4 heuristic misclassified whitespace-formatted
	// empty arrays as tools and "[{}]" as no tools; decode instead.
	declaresTools := func(raw json.RawMessage) bool {
		trimmed := bytes.TrimSpace(raw)
		if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) {
			return false
		}
		var entries []json.RawMessage
		if err := json.Unmarshal(trimmed, &entries); err != nil {
			return true
		}
		return len(entries) > 0
	}

	var outer map[string]json.RawMessage
	if err := json.Unmarshal(body, &outer); err != nil {
		return false
	}

	// Wrapped request.
	if reqRaw, ok := outer["request"]; ok {
		var inner map[string]json.RawMessage
		if json.Unmarshal(reqRaw, &inner) == nil && declaresTools(inner["tools"]) {
			return true
		}
	}

	// Top level.
	return declaresTools(outer["tools"])
}
|
|
|
|
// snapshotRequestBody reads req.Body fully and replaces it with an in-memory
// reader over the same bytes, so the body can be inspected here and still be
// sent later. It also installs req.GetBody so an http.Client can replay the
// body, e.g. on a 307/308 redirect.
//
// A nil request or a request without a body yields (nil, nil). The original
// body is always closed, including when reading it fails. Without that
// close, a failed read leaked the underlying reader. On a read error, req.Body
// is left as the consumed and closed original, and the error is returned.
func snapshotRequestBody(req *http.Request) ([]byte, error) {
	if req == nil || req.Body == nil {
		return nil, nil
	}
	body, err := io.ReadAll(req.Body)
	// Close error ignored: the bytes (or the read error) are all we need.
	_ = req.Body.Close()
	if err != nil {
		return nil, err
	}
	req.Body = io.NopCloser(bytes.NewReader(body))
	req.GetBody = func() (io.ReadCloser, error) {
		return io.NopCloser(bytes.NewReader(body)), nil
	}
	return body, nil
}
|
|
|
|
|