1355 lines
39 KiB
Go
1355 lines
39 KiB
Go
// LocalLS provides a gRPC client for the local Windsurf LanguageServerService.
|
||
//
|
||
// The correct chat flow routes through the local LS binary (which handles
|
||
// all auth/session management internally) rather than calling the upstream
|
||
// ApiServerService directly. The Cascade flow:
|
||
//
|
||
// 1. InitializeCascadePanelState — one-shot per LS session
|
||
// 2. AddTrackedWorkspace — one-shot per LS session
|
||
// 3. UpdateWorkspaceTrust — one-shot per LS session
|
||
// 4. StartCascade → cascade_id
|
||
// 5. SendUserCascadeMessage — send prompt + model config
|
||
// 6. GetCascadeTrajectorySteps — poll until GetCascadeTrajectory reports status IDLE (1)
|
||
package windsurf
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"crypto/sha256"
|
||
"crypto/tls"
|
||
"encoding/binary"
|
||
"encoding/hex"
|
||
"fmt"
|
||
"io"
|
||
"log/slog"
|
||
"net"
|
||
"net/http"
|
||
"net/url"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"golang.org/x/net/http2"
|
||
"google.golang.org/protobuf/proto"
|
||
|
||
pb "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb"
|
||
)
|
||
|
||
// Fully-qualified gRPC method paths of the local LanguageServerService.
// Each value is appended to LocalLSClient.BaseURL to form the request URL.

// StartCascadeRPC opens a new cascade and yields its cascade_id.
const StartCascadeRPC = "/exa.language_server_pb.LanguageServerService/StartCascade"

// InitPanelStateRPC initializes the Cascade panel state (first warmup call).
const InitPanelStateRPC = "/exa.language_server_pb.LanguageServerService/InitializeCascadePanelState"

// AddTrackedWorkspaceRPC registers an optional tracked workspace.
const AddTrackedWorkspaceRPC = "/exa.language_server_pb.LanguageServerService/AddTrackedWorkspace"

// UpdateWorkspaceTrustRPC marks the workspace as trusted (last warmup call).
const UpdateWorkspaceTrustRPC = "/exa.language_server_pb.LanguageServerService/UpdateWorkspaceTrust"

// SendUserCascadeMessageRPC posts a user message into a cascade.
const SendUserCascadeMessageRPC = "/exa.language_server_pb.LanguageServerService/SendUserCascadeMessage"

// GetCascadeTrajectoryStepsRPC fetches the trajectory steps of a cascade.
const GetCascadeTrajectoryStepsRPC = "/exa.language_server_pb.LanguageServerService/GetCascadeTrajectorySteps"

// GetCascadeTrajectoryStatusRPC fetches the trajectory itself; this file only reads its status field.
const GetCascadeTrajectoryStatusRPC = "/exa.language_server_pb.LanguageServerService/GetCascadeTrajectory"

// GetCascadeModelConfigsRPC fetches the Cascade model configurations.
const GetCascadeModelConfigsRPC = "/exa.language_server_pb.LanguageServerService/GetCascadeModelConfigs"
|
||
|
||
// cascadeModelCapsCacheEntry is the cached model-capability record kept for a
// single API key.
type cascadeModelCapsCacheEntry struct {
	// SupportsImages flags image support per model.
	// NOTE(review): presumably keyed by model UID — confirm against the code that fills the cache.
	SupportsImages map[string]bool

	// FetchedAt is the timestamp stored with the entry.
	// NOTE(review): presumably compared against cascadeModelCapsTTL for expiry — confirm.
	FetchedAt time.Time
}
|
||
|
||
// cascadeModelCapsTTL is the time-to-live of a capability cache entry (5 minutes).
const cascadeModelCapsTTL = 300 * time.Second
|
||
|
||
// LocalLSClient talks to the local Windsurf LanguageServerService via h2c (plain HTTP/2 over TCP).
|
||
type LocalLSClient struct {
|
||
BaseURL string
|
||
CSRFToken string
|
||
HTTP *http.Client
|
||
SessionID string
|
||
Warmed bool
|
||
// TrackedWorkspace is optional. When empty, the LS is treated as having no
|
||
// server-side repository context and relies on caller-provided tool results.
|
||
TrackedWorkspace string
|
||
mu sync.Mutex
|
||
|
||
// 模型能力缓存(per-API-key hash),供 Cascade 图像 gate 使用。
|
||
modelCapsMu sync.Mutex
|
||
modelCapsCache map[string]cascadeModelCapsCacheEntry
|
||
}
|
||
|
||
// NewLocalLSClient builds a client for the local LS at the given port.
|
||
func NewLocalLSClient(port int, csrfToken string) *LocalLSClient {
|
||
h2cTransport := &http2.Transport{
|
||
AllowHTTP: true,
|
||
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
|
||
return (&net.Dialer{Timeout: 5 * time.Second}).DialContext(ctx, network, addr)
|
||
},
|
||
}
|
||
return &LocalLSClient{
|
||
BaseURL: fmt.Sprintf("http://localhost:%d", port),
|
||
CSRFToken: csrfToken,
|
||
SessionID: generateUUID(),
|
||
TrackedWorkspace: "",
|
||
HTTP: &http.Client{
|
||
Transport: h2cTransport,
|
||
Timeout: 60 * time.Second,
|
||
},
|
||
}
|
||
}
|
||
|
||
// WarmupCascade runs the one-shot panel init sequence required before StartCascade.
|
||
// Idempotent — skip if already warmed.
|
||
func (l *LocalLSClient) WarmupCascade(ctx context.Context, token string) error {
|
||
return l.warmupCascade(ctx, token, false)
|
||
}
|
||
|
||
// ForceWarmupCascade resets session state and re-runs warmup.
|
||
func (l *LocalLSClient) ForceWarmupCascade(ctx context.Context, token string) error {
|
||
return l.warmupCascade(ctx, token, true)
|
||
}
|
||
|
||
func (l *LocalLSClient) warmupCascade(ctx context.Context, token string, force bool) error {
|
||
l.mu.Lock()
|
||
defer l.mu.Unlock()
|
||
|
||
if force {
|
||
l.Warmed = false
|
||
l.SessionID = generateUUID()
|
||
}
|
||
if l.Warmed {
|
||
return nil
|
||
}
|
||
if l.SessionID == "" {
|
||
l.SessionID = generateUUID()
|
||
}
|
||
|
||
var firstErr error
|
||
|
||
// InitializeCascadePanelState: F1=metadata, F3=workspace_trusted (bool, true)
|
||
initReq := encodeBytesField(1, buildMetadata(token, l.SessionID))
|
||
initReq = append(initReq, encodeVarintField(3, 1)...)
|
||
if err := l.grpcUnary(ctx, InitPanelStateRPC, initReq); err != nil {
|
||
firstErr = err
|
||
}
|
||
|
||
// AddTrackedWorkspace is optional. Default Windsurf mode should not pretend
|
||
// to have a mounted repository when the server does not actually have one.
|
||
if strings.TrimSpace(l.TrackedWorkspace) != "" {
|
||
addWsReq := encodeStringField(1, l.TrackedWorkspace)
|
||
_ = l.grpcUnary(ctx, AddTrackedWorkspaceRPC, addWsReq)
|
||
}
|
||
|
||
// UpdateWorkspaceTrust: F1=metadata, F2=workspace_trusted (bool, true)
|
||
trustReq := encodeBytesField(1, buildMetadata(token, l.SessionID))
|
||
trustReq = append(trustReq, encodeVarintField(2, 1)...)
|
||
if err := l.grpcUnary(ctx, UpdateWorkspaceTrustRPC, trustReq); err != nil && firstErr == nil {
|
||
firstErr = err
|
||
}
|
||
|
||
// Only mark warmed on success (unlike the old code which always set true)
|
||
if firstErr == nil {
|
||
l.Warmed = true
|
||
}
|
||
return firstErr
|
||
}
|
||
|
||
// StartCascade calls StartCascade and returns the cascade_id.
|
||
// Retries once on panel-state-not-found.
|
||
func (l *LocalLSClient) StartCascade(ctx context.Context, token string) (string, error) {
|
||
doStart := func() (string, error) {
|
||
body := encodeBytesField(1, buildMetadata(token, l.SessionID))
|
||
resp, err := l.grpcUnaryRaw(ctx, StartCascadeRPC, body)
|
||
if err != nil {
|
||
return "", fmt.Errorf("StartCascade: %w", err)
|
||
}
|
||
cascadeID, err := parseStringField1(resp)
|
||
if err != nil {
|
||
return "", fmt.Errorf("StartCascade parse: %w", err)
|
||
}
|
||
if cascadeID == "" {
|
||
return "", fmt.Errorf("StartCascade: empty cascade_id (hex=%x)", resp)
|
||
}
|
||
return cascadeID, nil
|
||
}
|
||
|
||
cascadeID, err := doStart()
|
||
if err != nil && isPanelStateNotFound(err) {
|
||
_ = l.ForceWarmupCascade(ctx, token)
|
||
return doStart()
|
||
}
|
||
return cascadeID, err
|
||
}
|
||
|
||
// SendUserCascadeMessage sends a message into an existing cascade session.
|
||
// Returns the (possibly new) cascadeID — it changes if panel-state retry triggers a new StartCascade.
|
||
// toolPreamble, if non-empty, is injected into the tool_calling_section override.
|
||
// SendUserCascadeMessage sends a user chat message to Cascade.
|
||
// Returns the (possibly new) cascadeID — it changes if panel-state retry triggers a new StartCascade.
|
||
// toolPreamble, if non-empty, is injected into the tool_calling_section override.
|
||
// images(可选)作为 SendUserCascadeMessageRequest.images (field 6) 追加到 proto wire。
|
||
//
|
||
// allowRecreate 控制 panel-state-not-found 时是否内部静默重建 cascade:
|
||
// - true:ForceWarmup + StartCascade + 用 SAME text 再发一次。调用方须保证
|
||
// text 已包含完整历史(否则新 cascade 无状态 + text 无历史 = 上下文丢失)。
|
||
// - false:直接返回错误,让调用方重建含完整历史的 text 后再调。
|
||
//
|
||
// 经验值:StreamCascadeChat 内当 reuseCascadeID 为空(本地 StartCascade 的流程,
|
||
// text 已是 full-history)时传 true;reuse 场景(text 可能仅含最后一条消息)传 false。
|
||
func (l *LocalLSClient) SendUserCascadeMessage(ctx context.Context, token, cascadeID, text, modelUID, toolPreamble string, modelEnumHint int, images []CascadeImage, allowRecreate bool) (string, error) {
|
||
modelEnum := resolveModelEnum(modelUID)
|
||
if modelEnum == 0 && modelEnumHint > 0 {
|
||
modelEnum = modelEnumHint
|
||
}
|
||
|
||
doSend := func(cid string) error {
|
||
body := encodeStringField(1, cid)
|
||
body = append(body, encodeBytesField(2, encodeStringField(1, text))...)
|
||
body = append(body, encodeBytesField(3, buildMetadata(token, l.SessionID))...)
|
||
body = append(body, encodeBytesField(5, buildCascadeConfig(modelUID, modelEnum, toolPreamble))...)
|
||
// field 6: repeated CodeiumImage images(逆向自 Windsurf.app chat-client)
|
||
body = appendSendUserCascadeImages(body, images)
|
||
return l.grpcUnary(ctx, SendUserCascadeMessageRPC, body)
|
||
}
|
||
|
||
if err := doSend(cascadeID); err != nil {
|
||
if !isPanelStateNotFound(err) {
|
||
return "", err
|
||
}
|
||
if !allowRecreate {
|
||
// reuse 场景:不要静默用 last-message-only 的 text 去灌新 cascade。
|
||
// 返回错误,让 chatCascade 用 full-history text 重建整个调用。
|
||
return "", err
|
||
}
|
||
_ = l.ForceWarmupCascade(ctx, token)
|
||
newCascadeID, startErr := l.StartCascade(ctx, token)
|
||
if startErr != nil {
|
||
return "", startErr
|
||
}
|
||
if err := doSend(newCascadeID); err != nil {
|
||
return "", err
|
||
}
|
||
return newCascadeID, nil
|
||
}
|
||
return cascadeID, nil
|
||
}
|
||
|
||
// buildMetadata builds the full Metadata proto for local LS calls, aligned with WindsurfAPI.
|
||
func buildMetadata(token, sessionID string) []byte {
|
||
if sessionID == "" {
|
||
sessionID = generateUUID()
|
||
}
|
||
var meta []byte
|
||
meta = append(meta, encodeStringField(1, AppName)...) // ide_name
|
||
meta = append(meta, encodeStringField(2, ExtensionVersion)...) // extension_version
|
||
meta = append(meta, encodeStringField(3, token)...) // api_key
|
||
meta = append(meta, encodeStringField(4, "en")...) // locale
|
||
meta = append(meta, encodeStringField(5, RuntimeOS())...) // os
|
||
meta = append(meta, encodeStringField(7, IDEVersion)...) // ide_version
|
||
meta = append(meta, encodeStringField(8, HardwareArch())...) // hardware
|
||
meta = append(meta, encodeVarintField(9, uint64(time.Now().UnixMilli()))...) // request_id
|
||
meta = append(meta, encodeStringField(10, sessionID)...) // session_id
|
||
meta = append(meta, encodeStringField(12, AppName)...) // extension_name
|
||
return meta
|
||
}
|
||
|
||
// buildSectionOverride builds a SectionOverrideConfig { mode=OVERRIDE(1), content=text }.
|
||
func buildSectionOverride(content string) []byte {
|
||
var out []byte
|
||
out = append(out, encodeVarintField(1, 1)...) // SECTION_OVERRIDE_MODE_OVERRIDE
|
||
out = append(out, encodeStringField(2, content)...)
|
||
return out
|
||
}
|
||
|
||
// buildCascadeConfig builds a CascadeConfig for the given model UID and enum.
|
||
// Uses NO_TOOL planner mode (3) with section overrides for pure conversational responses.
|
||
//
|
||
// Key insight (2026-04-12): NO_TOOL mode SUPPRESSES field 10 (tool_calling_section) —
|
||
// it is injected but never rendered to the model. Tool definitions MUST go into
|
||
// field 12 (additional_instructions_section) which IS rendered regardless of planner mode.
|
||
// Field 10 is kept as belt-and-suspenders.
|
||
func buildCascadeConfig(modelUID string, modelEnum int, toolPreamble string) []byte {
|
||
var convParts []byte
|
||
convParts = append(convParts, encodeVarintField(4, 3)...) // planner_mode=NO_TOOL(3)
|
||
|
||
const toolReinforcement = "\n\nThe functions listed above are available and callable. " +
|
||
"When the user's request can be answered by calling a function, emit a <tool_call> block as described. " +
|
||
"Use this exact format: <tool_call>{\"name\":\"...\",\"arguments\":{...}}</tool_call>"
|
||
|
||
if toolPreamble != "" {
|
||
// Primary: field 12 (additional_instructions_section) — always rendered in NO_TOOL mode
|
||
convParts = append(convParts, encodeBytesField(12, buildSectionOverride(toolPreamble+toolReinforcement))...)
|
||
// Belt-and-suspenders: field 10 (tool_calling_section)
|
||
convParts = append(convParts, encodeBytesField(10, buildSectionOverride(toolPreamble))...)
|
||
// field 13 (communication_section)
|
||
convParts = append(convParts, encodeBytesField(13, buildSectionOverride(
|
||
"You are accessed via API. Respond in the same language as the user. "+
|
||
"Use the functions above when relevant."))...)
|
||
} else {
|
||
// field 10: suppress built-in tool list
|
||
convParts = append(convParts, encodeBytesField(10, buildSectionOverride("No tools are available."))...)
|
||
// field 12: reinforce direct-answer mode
|
||
convParts = append(convParts, encodeBytesField(12, buildSectionOverride(
|
||
"You have no tools, no file access, and no command execution. "+
|
||
"Answer all questions directly using your knowledge. "+
|
||
"Never pretend to create files or check directories."))...)
|
||
// field 11 (code_changes_section): suppress IDE-specific boilerplate
|
||
convParts = append(convParts, encodeBytesField(11, buildSectionOverride(""))...)
|
||
// field 13 (communication_section)
|
||
convParts = append(convParts, encodeBytesField(13, buildSectionOverride(
|
||
"You are accessed via API. Answer directly. "+
|
||
"Respond in the same language as the user."))...)
|
||
}
|
||
|
||
// CortexPlannerConfig
|
||
var plannerParts []byte
|
||
plannerParts = append(plannerParts, encodeBytesField(2, convParts)...) // conversational=2
|
||
|
||
if modelUID != "" {
|
||
plannerParts = append(plannerParts, encodeStringField(35, modelUID)...)
|
||
plannerParts = append(plannerParts, encodeStringField(34, modelUID)...)
|
||
}
|
||
if modelEnum > 0 {
|
||
plannerParts = append(plannerParts, encodeBytesField(15, encodeVarintField(1, uint64(modelEnum)))...)
|
||
plannerParts = append(plannerParts, encodeVarintField(1, uint64(modelEnum))...)
|
||
}
|
||
|
||
// max_output_tokens (field 6) = 32768 — prevents long response truncation
|
||
plannerParts = append(plannerParts, encodeVarintField(6, 32768)...)
|
||
|
||
// BrainConfig: F1=enabled=true, F6=update_strategy{dynamic_update{}}
|
||
var brainParts []byte
|
||
brainParts = append(brainParts, encodeVarintField(1, 1)...)
|
||
brainParts = append(brainParts, encodeBytesField(6, encodeBytesField(6, nil))...)
|
||
|
||
// memory_config (field 5): {enabled=false} — prevent LS injecting user's stored memories
|
||
memoryConfig := encodeVarintField(1, 0) // bool enabled = false
|
||
|
||
var cfg []byte
|
||
cfg = append(cfg, encodeBytesField(1, plannerParts)...)
|
||
cfg = append(cfg, encodeBytesField(5, memoryConfig)...)
|
||
cfg = append(cfg, encodeBytesField(7, brainParts)...)
|
||
return cfg
|
||
}
|
||
|
||
// isPanelStateNotFound reports whether err looks like the language server's
// "panel state not found" failure. The match is case-insensitive on the error
// text: either the literal phrase, or both "not_found" and "panel" present.
// A nil error never matches.
func isPanelStateNotFound(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())
	if strings.Contains(msg, "panel state not found") {
		return true
	}
	return strings.Contains(msg, "not_found") && strings.Contains(msg, "panel")
}
|
||
|
||
// NativeToolCall is a structured tool call taken from trajectory step metadata
// or from a step's oneof fields (tool_call_proposal, mcp_tool).
//
// ArgumentsJSON holds the call arguments as a string.
// NOTE(review): presumably a JSON document, going by the name — confirm at the parser.
type NativeToolCall struct {
	ID, Name, ArgumentsJSON string
}
|
||
|
||
// TrajectoryStep holds the parsed content from a trajectory step.
|
||
type TrajectoryStep struct {
|
||
Type int
|
||
Status int
|
||
Text string // modifiedText || responseText (final preferred)
|
||
ResponseText string // raw responseText (field 20/1) — monotonic during streaming
|
||
Thinking string // field 20/3
|
||
ErrorText string // field 24 or field 31
|
||
Usage *StepUsage
|
||
ToolCall *NativeToolCall // structured tool call from metadata/step oneof
|
||
}
|
||
|
||
// StepUsage carries the server-reported token counts found in step metadata.
type StepUsage struct {
	InputTokens, OutputTokens, CacheReadTokens, CacheWriteTokens int
}
|
||
|
||
// GetTrajectoryStatus polls the trajectory status (field 2 varint).
|
||
// Returns 1 when the trajectory is IDLE (complete).
|
||
func (l *LocalLSClient) GetTrajectoryStatus(ctx context.Context, cascadeID string) (int, error) {
|
||
body := encodeStringField(1, cascadeID)
|
||
resp, err := l.grpcUnaryRaw(ctx, GetCascadeTrajectoryStatusRPC, body)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
status, _ := parseVarintField2(resp)
|
||
return int(status), nil
|
||
}
|
||
|
||
// GetTrajectorySteps fetches trajectory steps starting at stepOffset.
|
||
func (l *LocalLSClient) GetTrajectorySteps(ctx context.Context, cascadeID string, stepOffset int) ([]TrajectoryStep, error) {
|
||
body := encodeStringField(1, cascadeID)
|
||
if stepOffset > 0 {
|
||
body = append(body, encodeVarintField(2, uint64(stepOffset))...)
|
||
}
|
||
resp, err := l.grpcUnaryRaw(ctx, GetCascadeTrajectoryStepsRPC, body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return parseTrajectorySteps(resp), nil
|
||
}
|
||
|
||
// CascadeChatResult holds the full output from StreamCascadeChat.
|
||
type CascadeChatResult struct {
|
||
Text string
|
||
Thinking string
|
||
Usage *StepUsage // aggregated from all steps; nil if no server-reported data
|
||
CascadeID string
|
||
FirstTextAt time.Time // when text first appeared (zero if no text)
|
||
ToolCalls []NativeToolCall
|
||
}
|
||
|
||
// CascadeModelError is returned when the trajectory contains an error step
// (type=17) or the planner stalls. Callers should retry with a different account.
type CascadeModelError struct {
	Msg string
}

// Error returns the stored message, satisfying the error interface.
func (e *CascadeModelError) Error() string {
	return e.Msg
}
|
||
|
||
// StreamCascadeChat performs the full Cascade chat flow and returns accumulated text + thinking.
|
||
// Includes cold/warm stall detection, step error handling, and final sweep (aligned with JS v1.9).
|
||
// If reuseCascadeID is non-empty, skips StartCascade and reuses the existing cascade session.
|
||
// images 作为当前 user turn 的图像 sidecar 传递给 SendUserCascadeMessage(proto field 6)。
|
||
func (l *LocalLSClient) StreamCascadeChat(ctx context.Context, token, modelUID, userText, toolPreamble, reuseCascadeID string, modelEnumHint int, images []CascadeImage) (*CascadeChatResult, error) {
|
||
if err := l.WarmupCascade(ctx, token); err != nil {
|
||
return nil, fmt.Errorf("warmup: %w", err)
|
||
}
|
||
|
||
var cascadeID string
|
||
var err error
|
||
if reuseCascadeID != "" {
|
||
cascadeID = reuseCascadeID
|
||
} else {
|
||
cascadeID, err = l.StartCascade(ctx, token)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
// When reusing a cascade, capture the pre-existing step count so subsequent
|
||
// polls only fetch new steps. Without this, Turn 2 would re-read Turn 1's
|
||
// completed steps and append them again to the accumulated text, causing
|
||
// the response to duplicate Turn 1's prefix (including prior <tool_call>).
|
||
startStepIndex := 0
|
||
if reuseCascadeID != "" {
|
||
baselineSteps, berr := l.GetTrajectorySteps(ctx, cascadeID, 0)
|
||
if berr == nil {
|
||
startStepIndex = len(baselineSteps)
|
||
}
|
||
}
|
||
|
||
// allowRecreate=true 仅对本流程内 StartCascade 出来的全新 cascade 安全:
|
||
// 此时 userText 已是 full-history,内部遇到 panel-not-found 可静默重建再发。
|
||
// reuse 场景(caller 传入 reuseCascadeID)下 userText 可能只含最后一条消息,
|
||
// 静默重建会把空状态 cascade 当成有历史的 resume 用 → 上下文丢失,所以禁止。
|
||
allowRecreate := reuseCascadeID == ""
|
||
cascadeID, err = l.SendUserCascadeMessage(ctx, token, cascadeID, userText, modelUID, toolPreamble, modelEnumHint, images, allowRecreate)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("SendUserCascadeMessage: %w", err)
|
||
}
|
||
|
||
const (
|
||
maxWait = 180 * time.Second
|
||
idleGrace = 8 * time.Second
|
||
pollInterval = 250 * time.Millisecond
|
||
noGrowthStallMs = 25000
|
||
stallRetryMinLen = 300
|
||
)
|
||
|
||
textCursors := make(map[int]int)
|
||
thinkCursors := make(map[int]int)
|
||
var totalText, totalThinking int
|
||
var accText, accThinking string
|
||
var firstTextAt time.Time
|
||
idleCount := 0
|
||
sawActive := false
|
||
sawText := false
|
||
lastGrowthAt := time.Now()
|
||
|
||
// Native tool call tracking
|
||
seenToolCalls := make(map[string]bool)
|
||
var nativeToolCalls []NativeToolCall
|
||
lastStatus := 0
|
||
startTime := time.Now()
|
||
deadline := startTime.Add(maxWait)
|
||
graceEnd := startTime.Add(idleGrace)
|
||
inputChars := len(userText)
|
||
|
||
// Aggregated step usage
|
||
usageByStep := make(map[int]*StepUsage)
|
||
|
||
for time.Now().Before(deadline) {
|
||
select {
|
||
case <-ctx.Done():
|
||
return &CascadeChatResult{Text: SanitizePath(accText), Thinking: accThinking, CascadeID: cascadeID, FirstTextAt: firstTextAt, ToolCalls: nativeToolCalls}, ctx.Err()
|
||
default:
|
||
}
|
||
|
||
time.Sleep(pollInterval)
|
||
|
||
steps, err := l.GetTrajectorySteps(ctx, cascadeID, startStepIndex)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
|
||
// Check for error steps (type=17)
|
||
for _, s := range steps {
|
||
if s.Type == 17 && s.ErrorText != "" {
|
||
return nil, &CascadeModelError{Msg: s.ErrorText}
|
||
}
|
||
}
|
||
|
||
// Cold stall: active but no text/thinking after threshold
|
||
elapsed := time.Since(startTime)
|
||
coldThreshold := 30*time.Second + time.Duration(inputChars/1500)*5*time.Second
|
||
if coldThreshold > maxWait {
|
||
coldThreshold = maxWait
|
||
}
|
||
if elapsed > coldThreshold && sawActive && !sawText && totalThinking == 0 {
|
||
return nil, &CascadeModelError{Msg: fmt.Sprintf("Cascade planner stalled — no output after %ds", int(coldThreshold.Seconds()))}
|
||
}
|
||
|
||
for idx, s := range steps {
|
||
// Usage
|
||
if s.Usage != nil {
|
||
usageByStep[idx] = s.Usage
|
||
}
|
||
|
||
// Thinking delta
|
||
if s.Thinking != "" {
|
||
prev := thinkCursors[idx]
|
||
if len(s.Thinking) > prev {
|
||
accThinking += s.Thinking[prev:]
|
||
totalThinking += len(s.Thinking) - prev
|
||
thinkCursors[idx] = len(s.Thinking)
|
||
lastGrowthAt = time.Now()
|
||
}
|
||
}
|
||
|
||
// Native tool call from structured step data
|
||
if s.ToolCall != nil && s.ToolCall.Name != "" {
|
||
key := s.ToolCall.Name + "|" + s.ToolCall.ID
|
||
if !seenToolCalls[key] {
|
||
seenToolCalls[key] = true
|
||
nativeToolCalls = append(nativeToolCalls, *s.ToolCall)
|
||
lastGrowthAt = time.Now()
|
||
sawText = true
|
||
if firstTextAt.IsZero() {
|
||
firstTextAt = lastGrowthAt
|
||
}
|
||
}
|
||
}
|
||
|
||
// Text delta — use ResponseText during streaming for monotonic cursor
|
||
liveText := s.ResponseText
|
||
if liveText == "" {
|
||
liveText = s.Text
|
||
}
|
||
if liveText == "" {
|
||
continue
|
||
}
|
||
prev := textCursors[idx]
|
||
if len(liveText) > prev {
|
||
accText += liveText[prev:]
|
||
totalText += len(liveText) - prev
|
||
textCursors[idx] = len(liveText)
|
||
lastGrowthAt = time.Now()
|
||
if !sawText {
|
||
firstTextAt = lastGrowthAt
|
||
}
|
||
sawText = true
|
||
}
|
||
}
|
||
|
||
// Warm stall: text stopped growing for 25s while planner is active
|
||
if sawText && lastStatus != 1 && time.Since(lastGrowthAt).Milliseconds() > noGrowthStallMs {
|
||
if totalText < stallRetryMinLen {
|
||
return nil, &CascadeModelError{Msg: "Cascade planner stalled after preamble — no progress for 25s"}
|
||
}
|
||
break // accept partial result
|
||
}
|
||
|
||
status, err := l.GetTrajectoryStatus(ctx, cascadeID)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
lastStatus = status
|
||
|
||
if status != 1 {
|
||
sawActive = true
|
||
}
|
||
|
||
if status == 1 { // IDLE
|
||
if !sawActive && time.Now().Before(graceEnd) {
|
||
continue
|
||
}
|
||
idleCount++
|
||
growthSettled := time.Since(lastGrowthAt) > pollInterval*2
|
||
canBreak := false
|
||
if sawText {
|
||
canBreak = idleCount >= 2 && growthSettled
|
||
} else {
|
||
canBreak = idleCount >= 4
|
||
}
|
||
if canBreak {
|
||
// Final sweep: fetch one more time to get modifiedText top-up
|
||
finalSteps, err := l.GetTrajectorySteps(ctx, cascadeID, startStepIndex)
|
||
if err == nil {
|
||
for idx, s := range finalSteps {
|
||
if s.Usage != nil {
|
||
usageByStep[idx] = s.Usage
|
||
}
|
||
// Top up from responseText
|
||
rt := s.ResponseText
|
||
if rt == "" {
|
||
rt = s.Text
|
||
}
|
||
prev := textCursors[idx]
|
||
if len(rt) > prev {
|
||
accText += rt[prev:]
|
||
totalText += len(rt) - prev
|
||
textCursors[idx] = len(rt)
|
||
}
|
||
// Modified-response top-up: only if it extends what we already emitted
|
||
mt := s.Text // Text = modifiedText || responseText
|
||
cursor := textCursors[idx]
|
||
if len(mt) > cursor && strings.HasPrefix(mt, rt) {
|
||
accText += mt[cursor:]
|
||
totalText += len(mt) - cursor
|
||
textCursors[idx] = len(mt)
|
||
}
|
||
// Thinking final sweep
|
||
if s.Thinking != "" {
|
||
prev := thinkCursors[idx]
|
||
if len(s.Thinking) > prev {
|
||
accThinking += s.Thinking[prev:]
|
||
totalThinking += len(s.Thinking) - prev
|
||
thinkCursors[idx] = len(s.Thinking)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
break
|
||
}
|
||
} else {
|
||
idleCount = 0
|
||
}
|
||
}
|
||
|
||
slog.Info("windsurf_cascade_poll_result",
|
||
"cascade_id", cascadeID[:min(8, len(cascadeID))],
|
||
"acc_text_len", len(accText),
|
||
"acc_thinking_len", len(accThinking),
|
||
"native_tool_calls", len(nativeToolCalls),
|
||
"saw_active", sawActive,
|
||
"saw_text", sawText,
|
||
"steps_seen", len(textCursors),
|
||
"idle_count", idleCount,
|
||
)
|
||
|
||
// Aggregate step usage
|
||
var aggUsage *StepUsage
|
||
for _, u := range usageByStep {
|
||
if aggUsage == nil {
|
||
aggUsage = &StepUsage{}
|
||
}
|
||
aggUsage.InputTokens += u.InputTokens
|
||
aggUsage.OutputTokens += u.OutputTokens
|
||
aggUsage.CacheReadTokens += u.CacheReadTokens
|
||
aggUsage.CacheWriteTokens += u.CacheWriteTokens
|
||
}
|
||
|
||
return &CascadeChatResult{
|
||
Text: SanitizePath(accText),
|
||
Thinking: accThinking,
|
||
Usage: aggUsage,
|
||
CascadeID: cascadeID,
|
||
FirstTextAt: firstTextAt,
|
||
ToolCalls: nativeToolCalls,
|
||
}, nil
|
||
}
|
||
|
||
// ── gRPC helpers ───────────────────────────────────────────
|
||
|
||
func (l *LocalLSClient) grpcUnary(ctx context.Context, path string, body []byte) error {
|
||
_, err := l.grpcUnaryRaw(ctx, path, body)
|
||
return err
|
||
}
|
||
|
||
func (l *LocalLSClient) grpcUnaryRaw(ctx context.Context, path string, body []byte) ([]byte, error) {
|
||
env := make([]byte, 5+len(body))
|
||
env[0] = 0
|
||
binary.BigEndian.PutUint32(env[1:5], uint32(len(body)))
|
||
copy(env[5:], body)
|
||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, l.BaseURL+path, bytes.NewReader(env))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("build request: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", "application/grpc")
|
||
req.Header.Set("TE", "trailers")
|
||
req.Header.Set("User-Agent", "grpc-go/1.64.0")
|
||
if l.CSRFToken != "" {
|
||
req.Header.Set("x-codeium-csrf-token", l.CSRFToken)
|
||
}
|
||
slog.Debug("windsurf_grpc_request", "url", l.BaseURL+path, "csrf_token", l.CSRFToken)
|
||
|
||
resp, err := l.HTTP.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("roundtrip: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
respBody, _ := io.ReadAll(resp.Body)
|
||
|
||
if resp.StatusCode != 200 {
|
||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200))
|
||
}
|
||
|
||
grpcStatus := resp.Header.Get("grpc-status")
|
||
grpcMsg := resp.Header.Get("grpc-message")
|
||
if grpcStatus == "" {
|
||
grpcStatus = resp.Trailer.Get("grpc-status")
|
||
grpcMsg = resp.Trailer.Get("grpc-message")
|
||
}
|
||
if grpcStatus != "" && grpcStatus != "0" {
|
||
slog.Warn("windsurf_grpc_error",
|
||
"url", l.BaseURL+path,
|
||
"grpc_status", grpcStatus,
|
||
"grpc_msg", grpcMsg,
|
||
"http_status", resp.StatusCode,
|
||
"resp_headers", fmt.Sprintf("%v", resp.Header),
|
||
"resp_trailers", fmt.Sprintf("%v", resp.Trailer),
|
||
"body_len", len(respBody),
|
||
)
|
||
decoded, decErr := url.QueryUnescape(grpcMsg)
|
||
if decErr == nil {
|
||
grpcMsg = decoded
|
||
}
|
||
return nil, fmt.Errorf("gRPC status %s: %s", grpcStatus, grpcMsg)
|
||
}
|
||
|
||
return stripGRPCFrame(respBody), nil
|
||
}
|
||
|
||
// stripGRPCFrame removes the 5-byte gRPC message header (1 flag byte + 4-byte
// big-endian length) from data and returns the payload.
//
// Input shorter than a header is returned unchanged. If the declared length
// fits in data, exactly that many bytes are returned; otherwise everything
// after the header is returned.
func stripGRPCFrame(data []byte) []byte {
	const headerLen = 5
	if len(data) < headerLen {
		return data
	}

	declared := binary.BigEndian.Uint32(data[1:headerLen])
	if end := headerLen + int(declared); end <= len(data) {
		return data[headerLen:end]
	}
	return data[headerLen:]
}
|
||
|
||
// ── Model enum mapping ─────────────────────────────────────
|
||
|
||
// modelEnumByUID translates a modelUid string into its deprecated numeric enum
// value. Only models whose enum value is greater than zero are listed; the
// table was sourced from WindsurfAPI models.js. Entries are grouped by vendor
// and sorted by enum value within each group.
var modelEnumByUID = map[string]int{
	// Anthropic
	"MODEL_CLAUDE_4_SONNET":          281,
	"MODEL_CLAUDE_4_SONNET_THINKING": 282,
	"MODEL_CLAUDE_4_OPUS":            290,
	"MODEL_CLAUDE_4_OPUS_THINKING":   291,
	"MODEL_CLAUDE_4_1_OPUS":          328,
	"MODEL_CLAUDE_4_1_OPUS_THINKING": 329,
	"MODEL_PRIVATE_2":                353,
	"MODEL_PRIVATE_3":                354,
	"MODEL_CLAUDE_4_5_OPUS":          391,
	"MODEL_CLAUDE_4_5_OPUS_THINKING": 392,

	// OpenAI
	"MODEL_CHAT_GPT_4O_2024_08_06":  109,
	"MODEL_CHAT_O3":                 218,
	"MODEL_CHAT_GPT_4_1_2025_04_14": 259,
	"MODEL_PRIVATE_6":               340,
	"MODEL_CHAT_GPT_5_CODEX":        346,
	"MODEL_GPT_5_2_LOW":             400,
	"MODEL_GPT_5_2_MEDIUM":          401,
	"MODEL_GPT_5_2_HIGH":            402,
	"MODEL_GPT_5_2_XHIGH":           403,

	// Google
	"MODEL_GOOGLE_GEMINI_2_5_PRO":          246,
	"MODEL_GOOGLE_GEMINI_2_5_FLASH":        312,
	"MODEL_GOOGLE_GEMINI_3_0_PRO_LOW":      412,
	"MODEL_GOOGLE_GEMINI_3_0_FLASH_MEDIUM": 415,

	// Others
	"MODEL_XAI_GROK_3":   217,
	"MODEL_KIMI_K2":      323,
	"MODEL_SWE_1_5":      359,
	"MODEL_SWE_1_5_SLOW": 369,
	"MODEL_GLM_4_7":      417,
}
|
||
|
||
func resolveModelEnum(modelUID string) int {
|
||
if v, ok := modelEnumByUID[modelUID]; ok {
|
||
return v
|
||
}
|
||
return 0
|
||
}
|
||
|
||
// ── Proto parsers ──────────────────────────────────────────
|
||
|
||
func parseStringField1(data []byte) (string, error) {
|
||
pos := 0
|
||
for pos < len(data) {
|
||
tag, np, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np
|
||
fieldNum := tag >> 3
|
||
wireType := tag & 7
|
||
|
||
switch wireType {
|
||
case 2:
|
||
length, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
return "", fmt.Errorf("parse length at pos %d", pos)
|
||
}
|
||
pos = np2
|
||
if pos+int(length) > len(data) {
|
||
return "", fmt.Errorf("field out of bounds")
|
||
}
|
||
field := data[pos : pos+int(length)]
|
||
pos += int(length)
|
||
if fieldNum == 1 {
|
||
return string(field), nil
|
||
}
|
||
case 0:
|
||
_, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
case 1:
|
||
pos += 8
|
||
case 5:
|
||
pos += 4
|
||
default:
|
||
return "", fmt.Errorf("unknown wire type %d at pos %d", wireType, pos)
|
||
}
|
||
}
|
||
return "", nil
|
||
}
|
||
|
||
func parseVarintField2(data []byte) (uint64, error) {
|
||
pos := 0
|
||
for pos < len(data) {
|
||
tag, np, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np
|
||
fieldNum := tag >> 3
|
||
wireType := tag & 7
|
||
|
||
switch wireType {
|
||
case 0:
|
||
val, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
if fieldNum == 2 {
|
||
return val, nil
|
||
}
|
||
case 2:
|
||
length, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2 + int(length)
|
||
case 1:
|
||
pos += 8
|
||
case 5:
|
||
pos += 4
|
||
default:
|
||
break
|
||
}
|
||
}
|
||
return 0, nil
|
||
}
|
||
|
||
func parseTrajectorySteps(data []byte) []TrajectoryStep {
|
||
var steps []TrajectoryStep
|
||
pos := 0
|
||
for pos < len(data) {
|
||
tag, np, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np
|
||
fieldNum := tag >> 3
|
||
wireType := tag & 7
|
||
|
||
if wireType == 2 {
|
||
length, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
if pos+int(length) > len(data) {
|
||
break
|
||
}
|
||
field := data[pos : pos+int(length)]
|
||
pos += int(length)
|
||
if fieldNum == 1 {
|
||
steps = append(steps, parseOneTrajectoryStep(field))
|
||
}
|
||
} else if wireType == 0 {
|
||
_, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
} else if wireType == 1 {
|
||
pos += 8
|
||
} else if wireType == 5 {
|
||
pos += 4
|
||
} else {
|
||
break
|
||
}
|
||
}
|
||
return steps
|
||
}
|
||
|
||
func parseOneTrajectoryStep(data []byte) TrajectoryStep {
|
||
var s TrajectoryStep
|
||
pos := 0
|
||
for pos < len(data) {
|
||
tag, np, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np
|
||
fieldNum := tag >> 3
|
||
wireType := tag & 7
|
||
|
||
switch wireType {
|
||
case 0:
|
||
val, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
switch fieldNum {
|
||
case 1:
|
||
s.Type = int(val)
|
||
case 4:
|
||
s.Status = int(val)
|
||
}
|
||
case 2:
|
||
length, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
if pos+int(length) > len(data) {
|
||
break
|
||
}
|
||
field := data[pos : pos+int(length)]
|
||
pos += int(length)
|
||
switch fieldNum {
|
||
case 5: // step metadata (CortexStepMetadata)
|
||
s.Usage = parseStepUsage(field)
|
||
if tc := parseMetadataToolCall(field); tc != nil {
|
||
s.ToolCall = tc
|
||
}
|
||
case 47: // mcp_tool (CortexStepMcpTool) — tool_call is field 2
|
||
if tc := parseChatToolCallFromContainer(field, 2); tc != nil {
|
||
s.ToolCall = tc
|
||
}
|
||
case 49: // tool_call_proposal (CortexStepToolCallProposal) — tool_call is field 1
|
||
if tc := parseChatToolCallFromContainer(field, 1); tc != nil {
|
||
s.ToolCall = tc
|
||
}
|
||
case 20: // planner_response
|
||
pr := parseFields2(field)
|
||
var responseText, modifiedText, thinking string
|
||
for _, pf := range pr {
|
||
switch pf.fn {
|
||
case 1:
|
||
responseText = string(pf.val)
|
||
case 3:
|
||
thinking = string(pf.val)
|
||
case 8:
|
||
modifiedText = string(pf.val)
|
||
}
|
||
}
|
||
if modifiedText != "" {
|
||
s.Text = modifiedText
|
||
} else {
|
||
s.Text = responseText
|
||
}
|
||
s.ResponseText = responseText
|
||
s.Thinking = thinking
|
||
case 24: // error_message
|
||
s.ErrorText = extractErrorText(field)
|
||
case 31: // error (fallback)
|
||
if s.ErrorText == "" {
|
||
s.ErrorText = extractErrorText(field)
|
||
}
|
||
}
|
||
case 1:
|
||
pos += 8
|
||
case 5:
|
||
pos += 4
|
||
default:
|
||
pos = len(data)
|
||
}
|
||
}
|
||
return s
|
||
}
|
||
|
||
// parseChatToolCall parses a ChatToolCall proto message:
|
||
//
|
||
// field 1 (string) = id
|
||
// field 2 (string) = name
|
||
// field 3 (string) = arguments_json
|
||
func parseChatToolCall(data []byte) *NativeToolCall {
|
||
var tc NativeToolCall
|
||
pos := 0
|
||
for pos < len(data) {
|
||
tag, np, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np
|
||
fieldNum := tag >> 3
|
||
wireType := tag & 7
|
||
if wireType == 2 {
|
||
length, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
if pos+int(length) > len(data) {
|
||
break
|
||
}
|
||
val := string(data[pos : pos+int(length)])
|
||
pos += int(length)
|
||
switch fieldNum {
|
||
case 1:
|
||
tc.ID = val
|
||
case 2:
|
||
tc.Name = val
|
||
case 3:
|
||
tc.ArgumentsJSON = val
|
||
}
|
||
} else if wireType == 0 {
|
||
_, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
} else if wireType == 1 {
|
||
pos += 8
|
||
} else if wireType == 5 {
|
||
pos += 4
|
||
} else {
|
||
break
|
||
}
|
||
}
|
||
if tc.Name == "" {
|
||
return nil
|
||
}
|
||
return &tc
|
||
}
|
||
|
||
// parseChatToolCallFromContainer extracts ChatToolCall from a container message
|
||
// where the ChatToolCall is at the given field number.
|
||
func parseChatToolCallFromContainer(data []byte, toolCallFieldNum uint64) *NativeToolCall {
|
||
pos := 0
|
||
for pos < len(data) {
|
||
tag, np, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np
|
||
fieldNum := tag >> 3
|
||
wireType := tag & 7
|
||
if wireType == 2 {
|
||
length, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
if pos+int(length) > len(data) {
|
||
break
|
||
}
|
||
field := data[pos : pos+int(length)]
|
||
pos += int(length)
|
||
if fieldNum == toolCallFieldNum {
|
||
return parseChatToolCall(field)
|
||
}
|
||
} else if wireType == 0 {
|
||
_, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
} else if wireType == 1 {
|
||
pos += 8
|
||
} else if wireType == 5 {
|
||
pos += 4
|
||
} else {
|
||
break
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// parseMetadataToolCall extracts ChatToolCall from CortexStepMetadata (field 4 = tool_call).
|
||
func parseMetadataToolCall(metaData []byte) *NativeToolCall {
|
||
return parseChatToolCallFromContainer(metaData, 4)
|
||
}
|
||
|
||
// parseStepUsage extracts token usage from CortexStepMetadata (field 5).
|
||
// CortexStepMetadata.model_usage = field 9 → ModelUsageStats {2=input, 3=output, 4=cacheWrite, 5=cacheRead}
|
||
func parseStepUsage(metaData []byte) *StepUsage {
|
||
// Find field 9 (model_usage) in metadata
|
||
usageData := extractLenDelimField(metaData, 9)
|
||
if usageData == nil {
|
||
return nil
|
||
}
|
||
var u StepUsage
|
||
found := false
|
||
pos := 0
|
||
for pos < len(usageData) {
|
||
tag, np, ok := ReadVarint(usageData, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np
|
||
fn := tag >> 3
|
||
wt := tag & 7
|
||
if wt == 0 {
|
||
val, np2, ok := ReadVarint(usageData, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
switch fn {
|
||
case 2:
|
||
u.InputTokens = int(val)
|
||
found = true
|
||
case 3:
|
||
u.OutputTokens = int(val)
|
||
found = true
|
||
case 4:
|
||
u.CacheWriteTokens = int(val)
|
||
found = true
|
||
case 5:
|
||
u.CacheReadTokens = int(val)
|
||
found = true
|
||
}
|
||
} else if wt == 2 {
|
||
length, np2, ok := ReadVarint(usageData, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2 + int(length)
|
||
} else if wt == 1 {
|
||
pos += 8
|
||
} else if wt == 5 {
|
||
pos += 4
|
||
} else {
|
||
break
|
||
}
|
||
}
|
||
if !found {
|
||
return nil
|
||
}
|
||
return &u
|
||
}
|
||
|
||
// extractLenDelimField finds the first length-delimited field with the given number.
|
||
func extractLenDelimField(data []byte, targetField uint64) []byte {
|
||
pos := 0
|
||
for pos < len(data) {
|
||
tag, np, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np
|
||
fn := tag >> 3
|
||
wt := tag & 7
|
||
if wt == 2 {
|
||
length, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
if pos+int(length) > len(data) {
|
||
break
|
||
}
|
||
if fn == targetField {
|
||
return data[pos : pos+int(length)]
|
||
}
|
||
pos += int(length)
|
||
} else if wt == 0 {
|
||
_, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
} else if wt == 1 {
|
||
pos += 8
|
||
} else if wt == 5 {
|
||
pos += 4
|
||
} else {
|
||
break
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// protoField is one length-delimited field as collected by parseFields2.
type protoField struct {
	fn  uint64 // protobuf field number
	val []byte // raw field payload
}
|
||
|
||
func parseFields2(data []byte) []protoField {
|
||
var fields []protoField
|
||
pos := 0
|
||
for pos < len(data) {
|
||
tag, np, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np
|
||
fn := tag >> 3
|
||
wt := tag & 7
|
||
switch wt {
|
||
case 2:
|
||
length, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
if pos+int(length) > len(data) {
|
||
break
|
||
}
|
||
fields = append(fields, protoField{fn, data[pos : pos+int(length)]})
|
||
pos += int(length)
|
||
case 0:
|
||
_, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np2
|
||
case 1:
|
||
pos += 8
|
||
case 5:
|
||
pos += 4
|
||
default:
|
||
pos = len(data)
|
||
}
|
||
}
|
||
return fields
|
||
}
|
||
|
||
func extractErrorText(data []byte) string {
|
||
return extractErrorTextDepth(data, 0)
|
||
}
|
||
|
||
func extractErrorTextDepth(data []byte, depth int) string {
|
||
if depth > 3 {
|
||
return ""
|
||
}
|
||
for _, pf := range parseFields2(data) {
|
||
if pf.fn >= 1 && pf.fn <= 5 && len(pf.val) > 10 {
|
||
txt := string(pf.val)
|
||
for len(txt) > 0 && txt[0] < 0x20 {
|
||
txt = txt[1:]
|
||
}
|
||
if len(txt) > 10 && !hasNonPrintable(txt[:10]) {
|
||
return txt
|
||
}
|
||
if inner := extractErrorTextDepth(pf.val, depth+1); inner != "" {
|
||
return inner
|
||
}
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// hasNonPrintable reports whether s contains a control character below 0x20
// other than line feed ('\n') and carriage return ('\r'). Tab counts as
// non-printable; bytes at or above 0x20 never do.
func hasNonPrintable(s string) bool {
	// Scanning bytes matches scanning runes here: values below 0x20 only ever
	// appear in UTF-8 as single-byte characters.
	for i := 0; i < len(s); i++ {
		if b := s[i]; b < 0x20 && b != '\n' && b != '\r' {
			return true
		}
	}
	return false
}
|
||
|
||
// GetCascadeModelConfigs 查询 LS 的 GetCascadeModelConfigs RPC,
|
||
// 返回 model_name -> supports_images 的映射。模型名按小写归一化。
|
||
func (l *LocalLSClient) GetCascadeModelConfigs(ctx context.Context, token string) (map[string]bool, error) {
|
||
// 请求 body:只需 metadata,即 field 1 encode(Metadata)
|
||
// 参考 package.json 提到的 proto;这里用 metadata-only encoding 与其他 RPC 一致。
|
||
body := encodeBytesField(1, buildMetadata(token, l.SessionID))
|
||
raw, err := l.grpcUnaryRaw(ctx, GetCascadeModelConfigsRPC, body)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get_cascade_model_configs: %w", err)
|
||
}
|
||
var resp pb.GetCascadeModelConfigsResponse
|
||
if err := proto.Unmarshal(raw, &resp); err != nil {
|
||
return nil, fmt.Errorf("unmarshal: %w", err)
|
||
}
|
||
out := make(map[string]bool, len(resp.GetModels()))
|
||
for _, m := range resp.GetModels() {
|
||
out[strings.ToLower(strings.TrimSpace(m.GetName()))] = m.GetSupportsImages()
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
// ModelSupportsImages 带缓存的图像能力查询。
|
||
// fail-open:RPC 失败且无缓存时返回 (false, false, nil),由上层决定策略。
|
||
// 返回值:(found, supportsImages, error)
|
||
func (l *LocalLSClient) ModelSupportsImages(ctx context.Context, token, modelName string) (bool, bool, error) {
|
||
key := apiKeyHash(token)
|
||
|
||
l.modelCapsMu.Lock()
|
||
if l.modelCapsCache == nil {
|
||
l.modelCapsCache = make(map[string]cascadeModelCapsCacheEntry)
|
||
}
|
||
entry, ok := l.modelCapsCache[key]
|
||
fresh := ok && time.Since(entry.FetchedAt) < cascadeModelCapsTTL
|
||
l.modelCapsMu.Unlock()
|
||
|
||
if fresh {
|
||
v, found := entry.SupportsImages[strings.ToLower(strings.TrimSpace(modelName))]
|
||
return found, v, nil
|
||
}
|
||
|
||
// 拉新:失败时保留 stale
|
||
caps, err := l.GetCascadeModelConfigs(ctx, token)
|
||
if err != nil {
|
||
// stale fallback
|
||
if ok {
|
||
v, found := entry.SupportsImages[strings.ToLower(strings.TrimSpace(modelName))]
|
||
return found, v, nil
|
||
}
|
||
return false, false, err
|
||
}
|
||
|
||
l.modelCapsMu.Lock()
|
||
l.modelCapsCache[key] = cascadeModelCapsCacheEntry{
|
||
SupportsImages: caps,
|
||
FetchedAt: time.Now(),
|
||
}
|
||
l.modelCapsMu.Unlock()
|
||
|
||
v, found := caps[strings.ToLower(strings.TrimSpace(modelName))]
|
||
return found, v, nil
|
||
}
|
||
|
||
// apiKeyHash derives the cache key for a token: the first 8 bytes of its
// SHA-256 digest, hex-encoded as 16 lowercase characters.
func apiKeyHash(token string) string {
	const prefixBytes = 8 // 16 hex characters
	digest := sha256.Sum256([]byte(token))
	return hex.EncodeToString(digest[:prefixBytes])
}
|