1355 lines
39 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// LocalLS provides a gRPC client for the local Windsurf LanguageServerService.
//
// The correct chat flow routes through the local LS binary (which handles
// all auth/session management internally) rather than calling the upstream
// ApiServerService directly. The Cascade flow:
//
// 1. InitializeCascadePanelState — one-shot per LS session
// 2. AddTrackedWorkspace — one-shot per LS session
// 3. UpdateWorkspaceTrust — one-shot per LS session
// 4. StartCascade → cascade_id
// 5. SendUserCascadeMessage — send prompt + model config
// 6. GetCascadeTrajectorySteps / GetCascadeTrajectory — poll step content and run status until the trajectory is IDLE (status 1, as checked by StreamCascadeChat)
package windsurf
import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"encoding/binary"
"encoding/hex"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/net/http2"
"google.golang.org/protobuf/proto"
pb "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb"
)
// languageServerServicePathPrefix is the fully-qualified gRPC path prefix of the
// local LanguageServerService; every RPC path below is derived from it.
const languageServerServicePathPrefix = "/exa.language_server_pb.LanguageServerService/"

// gRPC method paths on the local LanguageServerService.
const (
	// Session warmup and cascade lifecycle.
	InitPanelStateRPC         = languageServerServicePathPrefix + "InitializeCascadePanelState"
	AddTrackedWorkspaceRPC    = languageServerServicePathPrefix + "AddTrackedWorkspace"
	UpdateWorkspaceTrustRPC   = languageServerServicePathPrefix + "UpdateWorkspaceTrust"
	StartCascadeRPC           = languageServerServicePathPrefix + "StartCascade"
	SendUserCascadeMessageRPC = languageServerServicePathPrefix + "SendUserCascadeMessage"

	// Trajectory polling. The "status" constant intentionally targets the
	// GetCascadeTrajectory method.
	GetCascadeTrajectoryStepsRPC  = languageServerServicePathPrefix + "GetCascadeTrajectorySteps"
	GetCascadeTrajectoryStatusRPC = languageServerServicePathPrefix + "GetCascadeTrajectory"

	// Model catalogue.
	GetCascadeModelConfigsRPC = languageServerServicePathPrefix + "GetCascadeModelConfigs"
)
// cascadeModelCapsCacheEntry is one cached snapshot of model capabilities for a
// single API key.
type cascadeModelCapsCacheEntry struct {
	// SupportsImages presumably maps a model identifier to whether it accepts
	// image input. NOTE(review): the key format is set by code not visible here.
	SupportsImages map[string]bool
	// FetchedAt is when the snapshot was fetched; presumably compared against
	// cascadeModelCapsTTL to detect staleness — confirm in the cache reader.
	FetchedAt time.Time
}
// cascadeModelCapsTTL is the lifetime of a cached model-capability entry
// (five minutes).
const cascadeModelCapsTTL = 300 * time.Second
// LocalLSClient talks to the local Windsurf LanguageServerService via h2c (plain HTTP/2 over TCP).
type LocalLSClient struct {
	// BaseURL is the LS origin (e.g. "http://localhost:<port>"); RPC paths are appended to it.
	BaseURL string
	// CSRFToken, when non-empty, is sent as the x-codeium-csrf-token header on every RPC.
	CSRFToken string
	// HTTP is the h2c-capable client used for all RPCs.
	HTTP *http.Client
	// SessionID is embedded in request metadata; a forced warmup replaces it.
	// NOTE(review): warmupCascade writes it under mu, but StartCascade and
	// SendUserCascadeMessage read it without the lock — racy if the client is
	// shared across goroutines.
	SessionID string
	// Warmed is set once the warmup sequence has completed without error.
	Warmed bool
	// TrackedWorkspace is optional. When empty, the LS is treated as having no
	// server-side repository context and relies on caller-provided tool results.
	TrackedWorkspace string
	// mu serializes warmupCascade, which reads and writes Warmed and SessionID.
	mu sync.Mutex
	// Model-capability cache keyed by a per-API-key hash, used by the Cascade
	// image gate; modelCapsMu guards modelCapsCache.
	modelCapsMu    sync.Mutex
	modelCapsCache map[string]cascadeModelCapsCacheEntry
}
// NewLocalLSClient returns a client for the language server listening on
// localhost:port. Connections use cleartext HTTP/2 (h2c) with a 5s dial
// timeout, requests are bounded by a 60s client timeout, and csrfToken is sent
// with every RPC when non-empty. A fresh session ID is generated and no
// workspace is tracked.
func NewLocalLSClient(port int, csrfToken string) *LocalLSClient {
	dialer := &net.Dialer{Timeout: 5 * time.Second}
	// h2c: the HTTP/2 transport dials through DialTLSContext, so returning a
	// plain TCP connection (ignoring the TLS config) gives HTTP/2 without TLS.
	transport := &http2.Transport{
		AllowHTTP: true,
		DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
			return dialer.DialContext(ctx, network, addr)
		},
	}

	client := &LocalLSClient{
		BaseURL:   fmt.Sprintf("http://localhost:%d", port),
		CSRFToken: csrfToken,
		HTTP:      &http.Client{Transport: transport, Timeout: 60 * time.Second},
	}
	client.SessionID = generateUUID()
	return client
}
// WarmupCascade performs the one-time panel/workspace initialisation that must
// precede StartCascade. It does nothing once a warmup has succeeded.
func (l *LocalLSClient) WarmupCascade(ctx context.Context, token string) error {
	const force = false
	return l.warmupCascade(ctx, token, force)
}
// ForceWarmupCascade discards the current session (new SessionID, Warmed
// cleared) and runs the warmup sequence again.
func (l *LocalLSClient) ForceWarmupCascade(ctx context.Context, token string) error {
	const force = true
	return l.warmupCascade(ctx, token, force)
}
// warmupCascade runs InitializeCascadePanelState, the optional
// AddTrackedWorkspace and UpdateWorkspaceTrust while holding l.mu.
//
// When force is true the session is reset first. Every step is attempted even
// if an earlier one fails. The first error from the two required steps is
// returned, and Warmed is set only when both succeeded.
func (l *LocalLSClient) warmupCascade(ctx context.Context, token string, force bool) error {
	l.mu.Lock()
	defer l.mu.Unlock()

	if force {
		l.SessionID = generateUUID()
		l.Warmed = false
	}
	switch {
	case l.Warmed:
		return nil
	case l.SessionID == "":
		l.SessionID = generateUUID()
	}

	var firstErr error
	remember := func(err error) {
		if firstErr == nil {
			firstErr = err
		}
	}

	// InitializeCascadePanelState: field 1 = metadata, field 3 = workspace_trusted (true).
	panelReq := append(encodeBytesField(1, buildMetadata(token, l.SessionID)), encodeVarintField(3, 1)...)
	remember(l.grpcUnary(ctx, InitPanelStateRPC, panelReq))

	// AddTrackedWorkspace is optional and best-effort (its error is deliberately
	// ignored): default mode must not pretend a repository is mounted.
	if ws := l.TrackedWorkspace; strings.TrimSpace(ws) != "" {
		_ = l.grpcUnary(ctx, AddTrackedWorkspaceRPC, encodeStringField(1, ws))
	}

	// UpdateWorkspaceTrust: field 1 = metadata, field 2 = workspace_trusted (true).
	trustReq := append(encodeBytesField(1, buildMetadata(token, l.SessionID)), encodeVarintField(2, 1)...)
	remember(l.grpcUnary(ctx, UpdateWorkspaceTrustRPC, trustReq))

	// Warmed is known to be false here, so this only ever flips it on success.
	l.Warmed = firstErr == nil
	return firstErr
}
// StartCascade opens a new Cascade conversation and returns its cascade_id.
//
// If the LS reports that the panel state is missing, the session is re-warmed
// (warmup errors are ignored; the retry surfaces any persistent failure) and the
// call is attempted exactly once more.
func (l *LocalLSClient) StartCascade(ctx context.Context, token string) (string, error) {
	attempt := func() (string, error) {
		resp, err := l.grpcUnaryRaw(ctx, StartCascadeRPC, encodeBytesField(1, buildMetadata(token, l.SessionID)))
		if err != nil {
			return "", fmt.Errorf("StartCascade: %w", err)
		}
		id, err := parseStringField1(resp)
		switch {
		case err != nil:
			return "", fmt.Errorf("StartCascade parse: %w", err)
		case id == "":
			return "", fmt.Errorf("StartCascade: empty cascade_id (hex=%x)", resp)
		}
		return id, nil
	}

	id, err := attempt()
	if err == nil || !isPanelStateNotFound(err) {
		return id, err
	}
	_ = l.ForceWarmupCascade(ctx, token)
	return attempt()
}
// SendUserCascadeMessage sends one user turn into the cascade identified by
// cascadeID and returns the ID of the cascade that actually received it.
//
// toolPreamble, when non-empty, is injected into the planner's section
// overrides. images are appended as SendUserCascadeMessageRequest.images
// (field 6), a layout reverse-engineered from the Windsurf.app chat client.
// modelEnumHint is used only when resolveModelEnum knows no enum for modelUID.
//
// allowRecreate decides what happens when the LS reports "panel state not found":
//   - true: re-warm, StartCascade a fresh cascade and resend the SAME text; the
//     new cascade's ID is returned. The caller must ensure text already holds
//     the full history, otherwise the stateless new cascade loses context.
//   - false: return the error so the caller can rebuild a full-history prompt.
//
// StreamCascadeChat passes true when it started the cascade itself (text is the
// full history) and false when reusing a cascade (text may hold only the last
// message).
func (l *LocalLSClient) SendUserCascadeMessage(ctx context.Context, token, cascadeID, text, modelUID, toolPreamble string, modelEnumHint int, images []CascadeImage, allowRecreate bool) (string, error) {
	modelEnum := resolveModelEnum(modelUID)
	if modelEnum == 0 && modelEnumHint > 0 {
		modelEnum = modelEnumHint
	}

	send := func(target string) error {
		var body []byte
		body = append(body, encodeStringField(1, target)...)                                         // field 1: cascade id
		body = append(body, encodeBytesField(2, encodeStringField(1, text))...)                      // field 2: user message text
		body = append(body, encodeBytesField(3, buildMetadata(token, l.SessionID))...)               // field 3: metadata
		body = append(body, encodeBytesField(5, buildCascadeConfig(modelUID, modelEnum, toolPreamble))...) // field 5: cascade config
		body = appendSendUserCascadeImages(body, images)                                             // field 6: images
		return l.grpcUnary(ctx, SendUserCascadeMessageRPC, body)
	}

	err := send(cascadeID)
	switch {
	case err == nil:
		return cascadeID, nil
	case !isPanelStateNotFound(err), !allowRecreate:
		// Reuse path: never silently replay last-message-only text into a fresh,
		// history-less cascade; the caller rebuilds with full history instead.
		return "", err
	}

	_ = l.ForceWarmupCascade(ctx, token)
	fresh, err := l.StartCascade(ctx, token)
	if err != nil {
		return "", err
	}
	if err = send(fresh); err != nil {
		return "", err
	}
	return fresh, nil
}
// buildMetadata encodes the Metadata message attached to local LS requests
// (field layout aligned with WindsurfAPI). An empty sessionID is replaced by a
// freshly generated UUID; request_id (field 9) is the current Unix time in ms.
func buildMetadata(token, sessionID string) []byte {
	if sessionID == "" {
		sessionID = generateUUID()
	}
	fields := [][]byte{
		encodeStringField(1, AppName),                         // ide_name
		encodeStringField(2, ExtensionVersion),                // extension_version
		encodeStringField(3, token),                           // api_key
		encodeStringField(4, "en"),                            // locale
		encodeStringField(5, RuntimeOS()),                     // os
		encodeStringField(7, IDEVersion),                      // ide_version
		encodeStringField(8, HardwareArch()),                  // hardware
		encodeVarintField(9, uint64(time.Now().UnixMilli())), // request_id
		encodeStringField(10, sessionID),                      // session_id
		encodeStringField(12, AppName),                        // extension_name
	}
	var meta []byte
	for _, f := range fields {
		meta = append(meta, f...)
	}
	return meta
}
// buildSectionOverride encodes SectionOverrideConfig{mode: OVERRIDE (1), content}.
func buildSectionOverride(content string) []byte {
	const sectionOverrideModeOverride = 1
	return append(
		append([]byte(nil), encodeVarintField(1, sectionOverrideModeOverride)...),
		encodeStringField(2, content)...,
	)
}
// buildCascadeConfig builds a CascadeConfig for the given model UID and enum.
// Uses NO_TOOL planner mode (3) with section overrides for pure conversational responses.
//
// Key insight (2026-04-12): NO_TOOL mode SUPPRESSES field 10 (tool_calling_section) —
// it is injected but never rendered to the model. Tool definitions MUST go into
// field 12 (additional_instructions_section) which IS rendered regardless of planner mode.
// Field 10 is kept as belt-and-suspenders.
//
// Wire layout produced, in emission order:
//
//	CascadeConfig
//	  1: CortexPlannerConfig
//	       2:      conversational config (4: planner_mode=3, plus section overrides)
//	       35, 34: modelUID (only when modelUID != "")
//	       15:     {1: modelEnum}, then 1: modelEnum (only when modelEnum > 0)
//	       6:      max output tokens = 32768
//	  5: memory config {1: enabled=false}
//	  7: BrainConfig {1: enabled=true, 6: {6: {}}}
//
// Section overrides: with a toolPreamble, fields 12 (preamble + reinforcement),
// 10 (preamble) and 13 (communication); without one, fields 10, 12, 11 (empty)
// and 13 carry "no tools / answer directly" instructions.
func buildCascadeConfig(modelUID string, modelEnum int, toolPreamble string) []byte {
	var convParts []byte
	convParts = append(convParts, encodeVarintField(4, 3)...) // planner_mode=NO_TOOL(3)
	// Appended after the tool preamble in field 12 to push the model towards the
	// textual <tool_call> format the caller parses.
	const toolReinforcement = "\n\nThe functions listed above are available and callable. " +
		"When the user's request can be answered by calling a function, emit a <tool_call> block as described. " +
		"Use this exact format: <tool_call>{\"name\":\"...\",\"arguments\":{...}}</tool_call>"
	if toolPreamble != "" {
		// Primary: field 12 (additional_instructions_section) — always rendered in NO_TOOL mode
		convParts = append(convParts, encodeBytesField(12, buildSectionOverride(toolPreamble+toolReinforcement))...)
		// Belt-and-suspenders: field 10 (tool_calling_section)
		convParts = append(convParts, encodeBytesField(10, buildSectionOverride(toolPreamble))...)
		// field 13 (communication_section)
		convParts = append(convParts, encodeBytesField(13, buildSectionOverride(
			"You are accessed via API. Respond in the same language as the user. "+
				"Use the functions above when relevant."))...)
	} else {
		// field 10: suppress built-in tool list
		convParts = append(convParts, encodeBytesField(10, buildSectionOverride("No tools are available."))...)
		// field 12: reinforce direct-answer mode
		convParts = append(convParts, encodeBytesField(12, buildSectionOverride(
			"You have no tools, no file access, and no command execution. "+
				"Answer all questions directly using your knowledge. "+
				"Never pretend to create files or check directories."))...)
		// field 11 (code_changes_section): suppress IDE-specific boilerplate
		convParts = append(convParts, encodeBytesField(11, buildSectionOverride(""))...)
		// field 13 (communication_section)
		convParts = append(convParts, encodeBytesField(13, buildSectionOverride(
			"You are accessed via API. Answer directly. "+
				"Respond in the same language as the user."))...)
	}
	// CortexPlannerConfig
	var plannerParts []byte
	plannerParts = append(plannerParts, encodeBytesField(2, convParts)...) // conversational=2
	if modelUID != "" {
		// The model UID is sent in both fields 35 and 34.
		// NOTE(review): which of the two the LS honours is not visible here.
		plannerParts = append(plannerParts, encodeStringField(35, modelUID)...)
		plannerParts = append(plannerParts, encodeStringField(34, modelUID)...)
	}
	if modelEnum > 0 {
		// The legacy enum is sent both nested (field 15 → field 1) and as field 1.
		plannerParts = append(plannerParts, encodeBytesField(15, encodeVarintField(1, uint64(modelEnum)))...)
		plannerParts = append(plannerParts, encodeVarintField(1, uint64(modelEnum))...)
	}
	// max_output_tokens (field 6) = 32768 — prevents long response truncation
	plannerParts = append(plannerParts, encodeVarintField(6, 32768)...)
	// BrainConfig: F1=enabled=true, F6=update_strategy{dynamic_update{}}
	var brainParts []byte
	brainParts = append(brainParts, encodeVarintField(1, 1)...)
	brainParts = append(brainParts, encodeBytesField(6, encodeBytesField(6, nil))...)
	// memory_config (field 5): {enabled=false} — prevent LS injecting user's stored memories
	memoryConfig := encodeVarintField(1, 0) // bool enabled = false
	var cfg []byte
	cfg = append(cfg, encodeBytesField(1, plannerParts)...)
	cfg = append(cfg, encodeBytesField(5, memoryConfig)...)
	cfg = append(cfg, encodeBytesField(7, brainParts)...)
	return cfg
}
// isPanelStateNotFound reports whether err looks like the LS's "panel state not
// found" failure: either that literal phrase, or a NOT_FOUND error that also
// mentions the panel. Matching is case-insensitive; a nil error never matches.
func isPanelStateNotFound(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())
	if strings.Contains(msg, "panel state not found") {
		return true
	}
	return strings.Contains(msg, "not_found") && strings.Contains(msg, "panel")
}
// NativeToolCall holds a structured tool call extracted from trajectory step metadata
// or from step oneof fields (tool_call_proposal, mcp_tool).
type NativeToolCall struct {
	// ID is ChatToolCall field 1.
	ID string
	// Name is ChatToolCall field 2; the parser discards calls without a name.
	Name string
	// ArgumentsJSON is ChatToolCall field 3, kept as the raw string from the LS
	// (presumably a JSON object — it is not validated here).
	ArgumentsJSON string
}
// TrajectoryStep holds the parsed content from a trajectory step.
type TrajectoryStep struct {
	Type         int             // step field 1; 17 marks an error step
	Status       int             // step field 4
	Text         string          // modifiedText || responseText (final preferred)
	ResponseText string          // raw responseText (field 20/1) — monotonic during streaming
	Thinking     string          // field 20/3
	ErrorText    string          // field 24 or field 31
	Usage        *StepUsage      // from step metadata (field 5); nil when not reported
	ToolCall     *NativeToolCall // structured tool call from metadata/step oneof
}
// StepUsage holds server-reported token counts from step metadata, decoded from
// ModelUsageStats varint fields 2 (input), 3 (output), 4 (cache write) and
// 5 (cache read).
type StepUsage struct {
	InputTokens      int
	OutputTokens     int
	CacheReadTokens  int
	CacheWriteTokens int
}
// GetTrajectoryStatus calls GetCascadeTrajectory for cascadeID and returns the
// trajectory status (varint field 2 of the response, 0 when absent).
// StreamCascadeChat treats 1 as IDLE, i.e. the turn is complete.
func (l *LocalLSClient) GetTrajectoryStatus(ctx context.Context, cascadeID string) (int, error) {
	raw, err := l.grpcUnaryRaw(ctx, GetCascadeTrajectoryStatusRPC, encodeStringField(1, cascadeID))
	if err != nil {
		return 0, err
	}
	// Parse errors are ignored: a missing or unreadable field yields status 0.
	status, _ := parseVarintField2(raw)
	return int(status), nil
}
// GetTrajectorySteps fetches and parses the trajectory steps of cascadeID. A
// positive stepOffset is sent as request field 2 so the LS starts from that
// step; zero or negative offsets omit the field.
func (l *LocalLSClient) GetTrajectorySteps(ctx context.Context, cascadeID string, stepOffset int) ([]TrajectoryStep, error) {
	req := encodeStringField(1, cascadeID)
	if stepOffset > 0 {
		req = append(req, encodeVarintField(2, uint64(stepOffset))...)
	}

	raw, err := l.grpcUnaryRaw(ctx, GetCascadeTrajectoryStepsRPC, req)
	if err != nil {
		return nil, err
	}
	steps := parseTrajectorySteps(raw)
	return steps, nil
}
// CascadeChatResult holds the full output from StreamCascadeChat.
type CascadeChatResult struct {
	Text        string           // accumulated assistant text, passed through SanitizePath
	Thinking    string           // accumulated thinking text
	Usage       *StepUsage       // aggregated from all steps; nil if no server-reported data
	CascadeID   string           // cascade that served the turn (may differ from the requested one after a recreate)
	FirstTextAt time.Time        // when text first appeared (zero if no text)
	ToolCalls   []NativeToolCall // native tool calls, de-duplicated by name+ID, in arrival order
}
// CascadeModelError is raised when the trajectory contains an error step (type=17)
// or the planner stalls. Callers should retry with a different account.
type CascadeModelError struct {
	// Msg is the error step's text or a description of the stall.
	Msg string
}

// Error implements the error interface by returning Msg verbatim.
func (e *CascadeModelError) Error() string { return e.Msg }
// StreamCascadeChat performs the full Cascade chat flow for one user turn and
// returns the accumulated text, thinking, native tool calls and usage.
//
// Flow: warm up the LS session; start a new cascade, or reuse reuseCascadeID
// when it is non-empty; send userText (plus toolPreamble, and images as the
// current turn's image sidecar in SendUserCascadeMessage field 6); then poll
// GetCascadeTrajectorySteps / GetCascadeTrajectory every 250ms, collecting
// per-step text and thinking deltas until the trajectory reports IDLE (status 1).
//
// Stall handling and the final sweep are aligned with the JS implementation (v1.9):
//   - a step of type 17 with error text aborts with *CascadeModelError;
//   - a cold stall (planner active, but no text or thinking after 30s plus 5s
//     per 1500 input bytes, capped at 180s) aborts with *CascadeModelError;
//   - a warm stall (no growth for 25s while not IDLE) aborts with
//     *CascadeModelError if fewer than 300 bytes of text were produced;
//     otherwise the partial text is accepted;
//   - after IDLE, a final sweep tops up text from modifiedText when it extends
//     the streamed responseText.
//
// If ctx is cancelled, the partial result collected so far is returned
// together with ctx.Err(). If 180s pass without reaching IDLE, the partial
// result is returned with a nil error and a warning is logged.
func (l *LocalLSClient) StreamCascadeChat(ctx context.Context, token, modelUID, userText, toolPreamble, reuseCascadeID string, modelEnumHint int, images []CascadeImage) (*CascadeChatResult, error) {
	if err := l.WarmupCascade(ctx, token); err != nil {
		return nil, fmt.Errorf("warmup: %w", err)
	}

	reusing := reuseCascadeID != ""
	cascadeID := reuseCascadeID
	if !reusing {
		started, err := l.StartCascade(ctx, token)
		if err != nil {
			return nil, err
		}
		cascadeID = started
	}

	// When reusing a cascade, capture the pre-existing step count so that polls
	// only fetch steps produced by this turn. Without it, turn 2 would re-read
	// turn 1's completed steps and append them again, duplicating turn 1's text
	// (including any prior <tool_call>).
	startStepIndex := 0
	if reusing {
		baselineSteps, err := l.GetTrajectorySteps(ctx, cascadeID, 0)
		if err != nil {
			// Best effort: continue from offset 0, but make the duplication risk visible.
			slog.Warn("windsurf_cascade_baseline_failed",
				"cascade_id", cascadeID[:min(8, len(cascadeID))],
				"error", err,
			)
		} else {
			startStepIndex = len(baselineSteps)
		}
	}

	// allowRecreate (!reusing) is only safe for a cascade freshly started above:
	// then userText already carries the full history, so a panel-not-found error
	// can be recovered by silently recreating the cascade and resending. When
	// reusing, userText may hold only the latest message; silently recreating
	// would treat an empty cascade as if it had history and lose context.
	cascadeID, err := l.SendUserCascadeMessage(ctx, token, cascadeID, userText, modelUID, toolPreamble, modelEnumHint, images, !reusing)
	if err != nil {
		return nil, fmt.Errorf("SendUserCascadeMessage: %w", err)
	}

	const (
		maxWait          = 180 * time.Second
		idleGrace        = 8 * time.Second
		pollInterval     = 250 * time.Millisecond
		noGrowthStallMs  = 25000
		stallRetryMinLen = 300
	)

	// Per-step cursors (keyed by index relative to startStepIndex) record how much
	// of each step's text/thinking was already emitted, so only deltas are appended.
	textCursors := make(map[int]int)
	thinkCursors := make(map[int]int)
	var accText, accThinking strings.Builder
	var firstTextAt time.Time
	idleCount := 0
	sawActive := false
	sawText := false
	lastGrowthAt := time.Now()
	lastStatus := 0

	// Native tool call tracking (de-duplicated by name|id).
	seenToolCalls := make(map[string]bool)
	var nativeToolCalls []NativeToolCall
	// Latest reported usage per step, aggregated into the result.
	usageByStep := make(map[int]*StepUsage)

	startTime := time.Now()
	deadline := startTime.Add(maxWait)
	graceEnd := startTime.Add(idleGrace)
	inputChars := len(userText) // bytes; only scales the cold-stall threshold
	coldThreshold := 30*time.Second + time.Duration(inputChars/1500)*5*time.Second
	if coldThreshold > maxWait {
		coldThreshold = maxWait
	}

	// appendDelta appends the not-yet-emitted suffix of text for step idx and
	// reports whether anything new was appended.
	appendDelta := func(acc *strings.Builder, cursors map[int]int, idx int, text string) bool {
		prev := cursors[idx]
		if len(text) <= prev {
			return false
		}
		acc.WriteString(text[prev:])
		cursors[idx] = len(text)
		return true
	}

	// buildResult snapshots everything collected so far, aggregating step usage.
	buildResult := func() *CascadeChatResult {
		var aggUsage *StepUsage
		for _, u := range usageByStep {
			if aggUsage == nil {
				aggUsage = &StepUsage{}
			}
			aggUsage.InputTokens += u.InputTokens
			aggUsage.OutputTokens += u.OutputTokens
			aggUsage.CacheReadTokens += u.CacheReadTokens
			aggUsage.CacheWriteTokens += u.CacheWriteTokens
		}
		return &CascadeChatResult{
			Text:        SanitizePath(accText.String()),
			Thinking:    accThinking.String(),
			Usage:       aggUsage,
			CascadeID:   cascadeID,
			FirstTextAt: firstTextAt,
			ToolCalls:   nativeToolCalls,
		}
	}

	exitReason := "deadline"
poll:
	for time.Now().Before(deadline) {
		// Wait for the next tick, but react to cancellation immediately.
		select {
		case <-ctx.Done():
			return buildResult(), ctx.Err()
		case <-time.After(pollInterval):
		}

		steps, err := l.GetTrajectorySteps(ctx, cascadeID, startStepIndex)
		if err != nil {
			continue // transient poll failure; retry on the next tick
		}

		// Error steps (type=17) abort the turn so the caller can retry elsewhere.
		for _, s := range steps {
			if s.Type == 17 && s.ErrorText != "" {
				return nil, &CascadeModelError{Msg: s.ErrorText}
			}
		}

		// Cold stall: active but no text/thinking after the threshold.
		if time.Since(startTime) > coldThreshold && sawActive && !sawText && accThinking.Len() == 0 {
			return nil, &CascadeModelError{Msg: fmt.Sprintf("Cascade planner stalled — no output after %ds", int(coldThreshold.Seconds()))}
		}

		for idx, s := range steps {
			if s.Usage != nil {
				usageByStep[idx] = s.Usage
			}
			if appendDelta(&accThinking, thinkCursors, idx, s.Thinking) {
				lastGrowthAt = time.Now()
			}
			if tc := s.ToolCall; tc != nil && tc.Name != "" {
				key := tc.Name + "|" + tc.ID
				if !seenToolCalls[key] {
					seenToolCalls[key] = true
					nativeToolCalls = append(nativeToolCalls, *tc)
					lastGrowthAt = time.Now()
					sawText = true
					if firstTextAt.IsZero() {
						firstTextAt = lastGrowthAt
					}
				}
			}
			// Stream from ResponseText (monotonic while generating); fall back to Text.
			liveText := s.ResponseText
			if liveText == "" {
				liveText = s.Text
			}
			if appendDelta(&accText, textCursors, idx, liveText) {
				lastGrowthAt = time.Now()
				if !sawText {
					firstTextAt = lastGrowthAt
				}
				sawText = true
			}
		}

		// Warm stall: text stopped growing for 25s while the planner is not IDLE.
		if sawText && lastStatus != 1 && time.Since(lastGrowthAt).Milliseconds() > noGrowthStallMs {
			if accText.Len() < stallRetryMinLen {
				return nil, &CascadeModelError{Msg: "Cascade planner stalled after preamble — no progress for 25s"}
			}
			exitReason = "warm_stall"
			break poll // accept the partial result
		}

		status, err := l.GetTrajectoryStatus(ctx, cascadeID)
		if err != nil {
			continue
		}
		lastStatus = status
		if status != 1 {
			sawActive = true
			idleCount = 0
			continue
		}

		// IDLE. Before the planner was ever seen active, give it a grace period to start.
		if !sawActive && time.Now().Before(graceEnd) {
			continue
		}
		idleCount++
		growthSettled := time.Since(lastGrowthAt) > pollInterval*2
		canBreak := idleCount >= 4
		if sawText {
			canBreak = idleCount >= 2 && growthSettled
		}
		if !canBreak {
			continue
		}

		// Final sweep: one more fetch to pick up modifiedText and late deltas.
		if finalSteps, err := l.GetTrajectorySteps(ctx, cascadeID, startStepIndex); err == nil {
			for idx, s := range finalSteps {
				if s.Usage != nil {
					usageByStep[idx] = s.Usage
				}
				rt := s.ResponseText
				if rt == "" {
					rt = s.Text
				}
				appendDelta(&accText, textCursors, idx, rt)
				// s.Text is modifiedText || responseText; only append it when it
				// extends what was already emitted from responseText.
				if strings.HasPrefix(s.Text, rt) {
					appendDelta(&accText, textCursors, idx, s.Text)
				}
				appendDelta(&accThinking, thinkCursors, idx, s.Thinking)
			}
		}
		exitReason = "idle"
		break poll
	}

	shortID := cascadeID[:min(8, len(cascadeID))]
	if exitReason == "deadline" {
		slog.Warn("windsurf_cascade_poll_deadline",
			"cascade_id", shortID,
			"max_wait", maxWait.String(),
			"acc_text_len", accText.Len(),
		)
	}
	slog.Info("windsurf_cascade_poll_result",
		"cascade_id", shortID,
		"acc_text_len", accText.Len(),
		"acc_thinking_len", accThinking.Len(),
		"native_tool_calls", len(nativeToolCalls),
		"saw_active", sawActive,
		"saw_text", sawText,
		"steps_seen", len(textCursors),
		"idle_count", idleCount,
		"exit_reason", exitReason,
	)

	return buildResult(), nil
}
// ── gRPC helpers ───────────────────────────────────────────
// grpcUnary performs a unary gRPC call and discards the response payload,
// returning only the transport or gRPC error.
func (l *LocalLSClient) grpcUnary(ctx context.Context, path string, body []byte) error {
	if _, err := l.grpcUnaryRaw(ctx, path, body); err != nil {
		return err
	}
	return nil
}
// grpcUnaryRaw performs a unary gRPC call over the h2c client and returns the
// response message with its 5-byte gRPC frame prefix removed.
//
// body must already be a serialized protobuf message; it is wrapped in a single
// uncompressed gRPC frame. An error is returned for a transport failure, a
// non-200 HTTP status, a failed body read, or a non-zero grpc-status. The
// status is read from the headers (trailers-only responses) or else from the
// trailers.
func (l *LocalLSClient) grpcUnaryRaw(ctx context.Context, path string, body []byte) ([]byte, error) {
	// Length-prefixed message: 1-byte compressed flag (0) + 4-byte big-endian length.
	env := make([]byte, 5+len(body))
	env[0] = 0
	binary.BigEndian.PutUint32(env[1:5], uint32(len(body)))
	copy(env[5:], body)

	endpoint := l.BaseURL + path
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(env))
	if err != nil {
		return nil, fmt.Errorf("build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/grpc")
	req.Header.Set("TE", "trailers")
	req.Header.Set("User-Agent", "grpc-go/1.64.0")
	if l.CSRFToken != "" {
		req.Header.Set("x-codeium-csrf-token", l.CSRFToken)
	}
	// Never log the CSRF token itself: it authenticates requests to the LS.
	slog.Debug("windsurf_grpc_request", "url", endpoint, "has_csrf_token", l.CSRFToken != "")

	resp, err := l.HTTP.Do(req)
	if err != nil {
		return nil, fmt.Errorf("roundtrip: %w", err)
	}
	defer resp.Body.Close()

	// Read the body to EOF before looking at trailers: resp.Trailer values are
	// only populated once the body has been fully consumed.
	respBody, readErr := io.ReadAll(resp.Body)
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200))
	}
	if readErr != nil {
		// A partial body would otherwise be parsed as if it were complete.
		return nil, fmt.Errorf("read response body: %w", readErr)
	}

	grpcStatus := resp.Header.Get("grpc-status")
	grpcMsg := resp.Header.Get("grpc-message")
	if grpcStatus == "" {
		grpcStatus = resp.Trailer.Get("grpc-status")
		grpcMsg = resp.Trailer.Get("grpc-message")
	}
	if grpcStatus != "" && grpcStatus != "0" {
		slog.Warn("windsurf_grpc_error",
			"url", endpoint,
			"grpc_status", grpcStatus,
			"grpc_msg", grpcMsg,
			"http_status", resp.StatusCode,
			"resp_headers", fmt.Sprintf("%v", resp.Header),
			"resp_trailers", fmt.Sprintf("%v", resp.Trailer),
			"body_len", len(respBody),
		)
		// grpc-message is percent-encoded. Use PathUnescape rather than
		// QueryUnescape: a literal '+' in the message is not an encoded space.
		if decoded, decErr := url.PathUnescape(grpcMsg); decErr == nil {
			grpcMsg = decoded
		}
		return nil, fmt.Errorf("gRPC status %s: %s", grpcStatus, grpcMsg)
	}
	return stripGRPCFrame(respBody), nil
}
// stripGRPCFrame removes the 5-byte gRPC message prefix (1-byte compressed
// flag + 4-byte big-endian length) from data and returns the first message's
// payload.
//
// Input shorter than the prefix is returned unchanged. If the declared length
// exceeds the available bytes, everything after the prefix is returned. The
// compressed flag is not inspected.
func stripGRPCFrame(data []byte) []byte {
	const prefixLen = 5
	if len(data) < prefixLen {
		return data
	}
	payload := data[prefixLen:]
	// Compare in uint64: on 32-bit platforms int(msgLen) can go negative for
	// huge declared lengths and turn the bounds check into a slice panic.
	if msgLen := binary.BigEndian.Uint32(data[1:prefixLen]); uint64(msgLen) <= uint64(len(payload)) {
		return payload[:msgLen]
	}
	return payload
}
// ── Model enum mapping ─────────────────────────────────────
// modelEnumByUID maps modelUid strings to their deprecated enum values.
// Only entries with enumValue > 0 are included. Sourced from WindsurfAPI models.js.
//
// Looked up through resolveModelEnum; a UID missing here resolves to 0, in which
// case SendUserCascadeMessage falls back to the caller's modelEnumHint.
// NOTE(review): the MODEL_PRIVATE_* entries are opaque UIDs; which public models
// they correspond to is not visible here.
var modelEnumByUID = map[string]int{
	// Anthropic
	"MODEL_CLAUDE_4_SONNET":          281,
	"MODEL_CLAUDE_4_SONNET_THINKING": 282,
	"MODEL_CLAUDE_4_OPUS":            290,
	"MODEL_CLAUDE_4_OPUS_THINKING":   291,
	"MODEL_CLAUDE_4_1_OPUS":          328,
	"MODEL_CLAUDE_4_1_OPUS_THINKING": 329,
	"MODEL_PRIVATE_2":                353,
	"MODEL_PRIVATE_3":                354,
	"MODEL_CLAUDE_4_5_OPUS":          391,
	"MODEL_CLAUDE_4_5_OPUS_THINKING": 392,
	// OpenAI
	"MODEL_CHAT_GPT_4O_2024_08_06":  109,
	"MODEL_CHAT_GPT_4_1_2025_04_14": 259,
	"MODEL_PRIVATE_6":               340,
	"MODEL_CHAT_GPT_5_CODEX":        346,
	"MODEL_GPT_5_2_LOW":             400,
	"MODEL_GPT_5_2_MEDIUM":          401,
	"MODEL_GPT_5_2_HIGH":            402,
	"MODEL_GPT_5_2_XHIGH":           403,
	"MODEL_CHAT_O3":                 218,
	// Google
	"MODEL_GOOGLE_GEMINI_2_5_PRO":          246,
	"MODEL_GOOGLE_GEMINI_2_5_FLASH":        312,
	"MODEL_GOOGLE_GEMINI_3_0_PRO_LOW":      412,
	"MODEL_GOOGLE_GEMINI_3_0_FLASH_MEDIUM": 415,
	// Others
	"MODEL_XAI_GROK_3":   217,
	"MODEL_KIMI_K2":      323,
	"MODEL_GLM_4_7":      417,
	"MODEL_SWE_1_5_SLOW": 369,
	"MODEL_SWE_1_5":      359,
}
// resolveModelEnum returns the legacy enum value registered for modelUID in
// modelEnumByUID, or 0 when the UID is unknown (the map's zero value).
func resolveModelEnum(modelUID string) int {
	return modelEnumByUID[modelUID]
}
// ── Proto parsers ──────────────────────────────────────────
// parseStringField1 returns the first length-delimited field 1 of the protobuf
// message data as a string (StartCascade uses it for cascade_id).
//
// It returns "" with a nil error when field 1 is absent or the message ends in
// a truncated tag/varint. It returns an error when a length prefix is truncated,
// when a declared length overruns data, or for an unsupported wire type
// (groups 3/4, 6, 7).
func parseStringField1(data []byte) (string, error) {
	pos := 0
	for pos < len(data) {
		tag, np, ok := ReadVarint(data, pos)
		if !ok {
			break
		}
		pos = np
		fieldNum, wireType := tag>>3, tag&7
		switch wireType {
		case 0: // varint
			_, np2, ok := ReadVarint(data, pos)
			if !ok {
				return "", nil // truncated varint: nothing further can be read
			}
			pos = np2
		case 1: // fixed64
			pos += 8
		case 2: // length-delimited
			length, np2, ok := ReadVarint(data, pos)
			if !ok {
				return "", fmt.Errorf("parse length at pos %d", pos)
			}
			pos = np2
			// Compare in uint64: converting a huge length to int can overflow and
			// turn the bounds check into a slice panic.
			if length > uint64(len(data)-pos) {
				return "", fmt.Errorf("field out of bounds")
			}
			end := pos + int(length)
			if fieldNum == 1 {
				return string(data[pos:end]), nil
			}
			pos = end
		case 5: // fixed32
			pos += 4
		default:
			return "", fmt.Errorf("unknown wire type %d at pos %d", wireType, pos)
		}
	}
	return "", nil
}
// parseVarintField2 returns the first varint field 2 of the protobuf message
// data (GetTrajectoryStatus uses it for the trajectory status).
//
// It returns 0 when the field is absent or parsing stops first. Parsing stops
// at a truncated tag/varint, a length-delimited field that overruns data, or an
// unsupported wire type. The error result is always nil and is kept for
// signature compatibility.
func parseVarintField2(data []byte) (uint64, error) {
	pos := 0
	for pos < len(data) {
		tag, np, ok := ReadVarint(data, pos)
		if !ok {
			return 0, nil
		}
		pos = np
		switch tag & 7 {
		case 0: // varint
			val, np2, ok := ReadVarint(data, pos)
			if !ok {
				return 0, nil
			}
			if tag>>3 == 2 {
				return val, nil
			}
			pos = np2
		case 1: // fixed64
			pos += 8
		case 2: // length-delimited: skip, guarding against overflowing lengths
			length, np2, ok := ReadVarint(data, pos)
			if !ok || length > uint64(len(data)-np2) {
				return 0, nil
			}
			pos = np2 + int(length)
		case 5: // fixed32
			pos += 4
		default:
			// Groups and invalid wire types: stop instead of reading payload
			// bytes as if they were tags.
			return 0, nil
		}
	}
	return 0, nil
}
// parseTrajectorySteps decodes every repeated field 1 (one trajectory step) of a
// GetCascadeTrajectorySteps response, in wire order.
//
// Parsing stops at the first truncated or malformed field; steps decoded before
// that point are still returned.
func parseTrajectorySteps(data []byte) []TrajectoryStep {
	var steps []TrajectoryStep
	pos := 0
scan:
	for pos < len(data) {
		tag, np, ok := ReadVarint(data, pos)
		if !ok {
			break
		}
		pos = np
		switch tag & 7 {
		case 0: // varint
			_, np2, ok := ReadVarint(data, pos)
			if !ok {
				break scan
			}
			pos = np2
		case 1: // fixed64
			pos += 8
		case 2: // length-delimited
			length, np2, ok := ReadVarint(data, pos)
			// Compare in uint64 so a huge length cannot overflow int.
			if !ok || length > uint64(len(data)-np2) {
				break scan
			}
			end := np2 + int(length)
			if tag>>3 == 1 {
				steps = append(steps, parseOneTrajectoryStep(data[np2:end]))
			}
			pos = end
		case 5: // fixed32
			pos += 4
		default:
			break scan
		}
	}
	return steps
}
// parseOneTrajectoryStep decodes a single trajectory step message:
//
//	field 1  (varint) type
//	field 4  (varint) status
//	field 5  (bytes)  step metadata: usage and tool call
//	field 20 (bytes)  planner_response: 1 = response text, 3 = thinking, 8 = modified text
//	field 24 (bytes)  error_message
//	field 31 (bytes)  error (used only if no error text was set earlier)
//	field 47 (bytes)  mcp_tool: ChatToolCall at field 2
//	field 49 (bytes)  tool_call_proposal: ChatToolCall at field 1
//
// Parsing stops at the first truncated or malformed field; fields decoded
// before that point are kept.
func parseOneTrajectoryStep(data []byte) TrajectoryStep {
	var s TrajectoryStep
	pos := 0
scan:
	for pos < len(data) {
		tag, np, ok := ReadVarint(data, pos)
		if !ok {
			break
		}
		pos = np
		fieldNum := tag >> 3
		switch tag & 7 {
		case 0: // varint
			val, np2, ok := ReadVarint(data, pos)
			if !ok {
				break scan
			}
			pos = np2
			switch fieldNum {
			case 1:
				s.Type = int(val)
			case 4:
				s.Status = int(val)
			}
		case 1: // fixed64
			pos += 8
		case 2: // length-delimited
			length, np2, ok := ReadVarint(data, pos)
			// A truncated/overlong field makes the rest unreliable: stop rather than
			// read its payload bytes as tags. Compare in uint64 to avoid overflow.
			if !ok || length > uint64(len(data)-np2) {
				break scan
			}
			end := np2 + int(length)
			field := data[np2:end]
			pos = end
			switch fieldNum {
			case 5: // step metadata (CortexStepMetadata)
				s.Usage = parseStepUsage(field)
				if tc := parseMetadataToolCall(field); tc != nil {
					s.ToolCall = tc
				}
			case 47: // mcp_tool (CortexStepMcpTool) — tool_call is field 2
				if tc := parseChatToolCallFromContainer(field, 2); tc != nil {
					s.ToolCall = tc
				}
			case 49: // tool_call_proposal (CortexStepToolCallProposal) — tool_call is field 1
				if tc := parseChatToolCallFromContainer(field, 1); tc != nil {
					s.ToolCall = tc
				}
			case 20: // planner_response
				var responseText, modifiedText, thinking string
				for _, pf := range parseFields2(field) {
					switch pf.fn {
					case 1:
						responseText = string(pf.val)
					case 3:
						thinking = string(pf.val)
					case 8:
						modifiedText = string(pf.val)
					}
				}
				// Text prefers the modified (final) response over the raw one.
				s.Text = responseText
				if modifiedText != "" {
					s.Text = modifiedText
				}
				s.ResponseText = responseText
				s.Thinking = thinking
			case 24: // error_message
				s.ErrorText = extractErrorText(field)
			case 31: // error (fallback)
				if s.ErrorText == "" {
					s.ErrorText = extractErrorText(field)
				}
			}
		case 5: // fixed32
			pos += 4
		default:
			break scan
		}
	}
	return s
}
// parseChatToolCall parses a ChatToolCall proto message:
//
//	field 1 (string) = id
//	field 2 (string) = name
//	field 3 (string) = arguments_json
//
// It returns nil when no name was decoded. Parsing stops at the first truncated
// or malformed field.
func parseChatToolCall(data []byte) *NativeToolCall {
	var tc NativeToolCall
	pos := 0
scan:
	for pos < len(data) {
		tag, np, ok := ReadVarint(data, pos)
		if !ok {
			break
		}
		pos = np
		switch tag & 7 {
		case 0: // varint
			_, np2, ok := ReadVarint(data, pos)
			if !ok {
				break scan
			}
			pos = np2
		case 1: // fixed64
			pos += 8
		case 2: // length-delimited
			length, np2, ok := ReadVarint(data, pos)
			// Compare in uint64 so a huge length cannot overflow int.
			if !ok || length > uint64(len(data)-np2) {
				break scan
			}
			end := np2 + int(length)
			switch tag >> 3 {
			case 1:
				tc.ID = string(data[np2:end])
			case 2:
				tc.Name = string(data[np2:end])
			case 3:
				tc.ArgumentsJSON = string(data[np2:end])
			}
			pos = end
		case 5: // fixed32
			pos += 4
		default:
			break scan
		}
	}
	if tc.Name == "" {
		return nil
	}
	return &tc
}
// parseChatToolCallFromContainer extracts ChatToolCall from a container message
// where the ChatToolCall is at the given field number.
//
// Only the first occurrence of toolCallFieldNum is parsed; nil is returned when
// it is absent, nameless, or parsing stops at a truncated/malformed field first.
func parseChatToolCallFromContainer(data []byte, toolCallFieldNum uint64) *NativeToolCall {
	pos := 0
scan:
	for pos < len(data) {
		tag, np, ok := ReadVarint(data, pos)
		if !ok {
			break
		}
		pos = np
		switch tag & 7 {
		case 0: // varint
			_, np2, ok := ReadVarint(data, pos)
			if !ok {
				break scan
			}
			pos = np2
		case 1: // fixed64
			pos += 8
		case 2: // length-delimited
			length, np2, ok := ReadVarint(data, pos)
			// Compare in uint64 so a huge length cannot overflow int.
			if !ok || length > uint64(len(data)-np2) {
				break scan
			}
			end := np2 + int(length)
			if tag>>3 == toolCallFieldNum {
				return parseChatToolCall(data[np2:end])
			}
			pos = end
		case 5: // fixed32
			pos += 4
		default:
			break scan
		}
	}
	return nil
}
// parseMetadataToolCall returns the ChatToolCall held in field 4 (tool_call) of
// a CortexStepMetadata message, or nil when there is none.
func parseMetadataToolCall(metaData []byte) *NativeToolCall {
	const metadataToolCallField = 4
	return parseChatToolCallFromContainer(metaData, metadataToolCallField)
}
// parseStepUsage extracts token usage from CortexStepMetadata (step field 5).
// CortexStepMetadata.model_usage = field 9 → ModelUsageStats {2=input, 3=output, 4=cacheWrite, 5=cacheRead}
//
// It returns nil when model_usage is absent or carries none of those fields.
// Parsing stops at the first truncated or malformed field.
func parseStepUsage(metaData []byte) *StepUsage {
	usageData := extractLenDelimField(metaData, 9)
	if usageData == nil {
		return nil
	}
	var u StepUsage
	found := false
	pos := 0
scan:
	for pos < len(usageData) {
		tag, np, ok := ReadVarint(usageData, pos)
		if !ok {
			break
		}
		pos = np
		switch tag & 7 {
		case 0: // varint
			val, np2, ok := ReadVarint(usageData, pos)
			if !ok {
				break scan
			}
			pos = np2
			switch tag >> 3 {
			case 2:
				u.InputTokens = int(val)
				found = true
			case 3:
				u.OutputTokens = int(val)
				found = true
			case 4:
				u.CacheWriteTokens = int(val)
				found = true
			case 5:
				u.CacheReadTokens = int(val)
				found = true
			}
		case 1: // fixed64
			pos += 8
		case 2: // length-delimited: skip, guarding against overflowing lengths
			length, np2, ok := ReadVarint(usageData, pos)
			if !ok || length > uint64(len(usageData)-np2) {
				break scan
			}
			pos = np2 + int(length)
		case 5: // fixed32
			pos += 4
		default:
			break scan
		}
	}
	if !found {
		return nil
	}
	return &u
}
// extractLenDelimField returns the payload of the first length-delimited
// (wire type 2) field numbered targetField in a protobuf-encoded message.
//
// Varint, fixed64 and fixed32 fields are skipped, as are length-delimited
// fields with other numbers. The returned slice aliases data. It returns nil
// if the field is not present, or if the message is malformed before the
// field is reached. Malformed means a bad varint, a length past the end of
// data, or an unsupported wire type.
func extractLenDelimField(data []byte, targetField uint64) []byte {
	pos := 0
	for pos < len(data) {
		tag, next, ok := ReadVarint(data, pos)
		if !ok {
			return nil
		}
		pos = next
		switch tag & 7 {
		case 0:
			_, after, ok := ReadVarint(data, pos)
			if !ok {
				return nil
			}
			pos = after
		case 1:
			pos += 8
		case 2:
			length, start, ok := ReadVarint(data, pos)
			// Compare in uint64 space: int(length) could otherwise wrap
			// negative, pass the bounds check, and panic when slicing.
			if !ok || start > len(data) || length > uint64(len(data)-start) {
				return nil
			}
			end := start + int(length)
			if tag>>3 == targetField {
				return data[start:end]
			}
			pos = end
		case 5:
			pos += 4
		default:
			return nil
		}
	}
	return nil
}
// protoField is one length-delimited field collected by parseFields2: its
// protobuf field number and the raw payload bytes. val is a sub-slice of the
// decoded buffer, not a copy.
type protoField struct {
	fn  uint64 // protobuf field number
	val []byte // raw payload bytes of the field
}
// parseFields2 decodes the top level of a protobuf-encoded message and
// returns every length-delimited (wire type 2) field in encounter order.
// Varint, fixed64 and fixed32 fields are skipped. Each val aliases data.
//
// Decoding stops at the first malformed or truncated field: a bad varint, a
// length running past the end of data, or an unsupported wire type. Fields
// decoded before that point are still returned. Stopping (rather than
// resuming inside a truncated payload) avoids emitting bogus fields parsed
// from the middle of another field.
func parseFields2(data []byte) []protoField {
	var fields []protoField
	pos := 0
	for pos < len(data) {
		tag, np, ok := ReadVarint(data, pos)
		if !ok {
			return fields
		}
		pos = np
		fn := tag >> 3
		switch tag & 7 {
		case 2:
			length, np2, ok := ReadVarint(data, pos)
			// Compare in uint64 space so a corrupt length cannot wrap to a
			// negative int and panic when slicing.
			if !ok || np2 > len(data) || length > uint64(len(data)-np2) {
				return fields
			}
			end := np2 + int(length)
			fields = append(fields, protoField{fn, data[np2:end]})
			pos = end
		case 0:
			_, np2, ok := ReadVarint(data, pos)
			if !ok {
				return fields
			}
			pos = np2
		case 1:
			pos += 8
		case 5:
			pos += 4
		default:
			return fields
		}
	}
	return fields
}
// extractErrorText searches a protobuf-encoded message, starting at its top
// level, for the first field that looks like human-readable text, and returns
// it. Presumably this is used on error payloads (callers are not visible
// here). It returns "" when nothing text-like is found.
func extractErrorText(data []byte) string {
	const topLevel = 0
	return extractErrorTextDepth(data, topLevel)
}
// extractErrorTextDepth performs a bounded depth-first search over the
// length-delimited fields numbered 1 through 5 of data, looking for text.
// Only depths 0 through 3 are searched.
//
// A candidate field must be longer than 10 bytes. Its leading bytes below
// 0x20 are stripped. If the remainder is still longer than 10 bytes, and its
// first 10 bytes contain no control characters other than '\n' and '\r', the
// remainder is returned. Otherwise the field is treated as a nested message
// and searched one level deeper. It returns "" when nothing qualifies.
func extractErrorTextDepth(data []byte, depth int) string {
	const maxDepth = 3
	if depth > maxDepth {
		return ""
	}
	for _, f := range parseFields2(data) {
		if f.fn < 1 || f.fn > 5 || len(f.val) <= 10 {
			continue
		}
		start := 0
		for start < len(f.val) && f.val[start] < 0x20 {
			start++
		}
		candidate := string(f.val[start:])
		if len(candidate) > 10 && !hasNonPrintable(candidate[:10]) {
			return candidate
		}
		if nested := extractErrorTextDepth(f.val, depth+1); nested != "" {
			return nested
		}
	}
	return ""
}
// hasNonPrintable reports whether s contains a control character below 0x20
// other than '\n' or '\r'. Tabs count as non-printable.
func hasNonPrintable(s string) bool {
	isControl := func(r rune) bool {
		return r < 0x20 && r != '\n' && r != '\r'
	}
	return strings.IndexFunc(s, isControl) >= 0
}
// GetCascadeModelConfigs calls the LS GetCascadeModelConfigs RPC and returns
// a map from model name to that model's supports_images flag. Model names are
// trimmed and lower-cased so lookups can be case-insensitive.
func (l *LocalLSClient) GetCascadeModelConfigs(ctx context.Context, token string) (map[string]bool, error) {
	// The request body only carries metadata in field 1 (encode(Metadata)).
	// Per the original note, this follows the proto referenced in package.json
	// and matches the metadata-only encoding used by the other RPCs.
	req := encodeBytesField(1, buildMetadata(token, l.SessionID))
	payload, err := l.grpcUnaryRaw(ctx, GetCascadeModelConfigsRPC, req)
	if err != nil {
		return nil, fmt.Errorf("get_cascade_model_configs: %w", err)
	}
	resp := &pb.GetCascadeModelConfigsResponse{}
	if err = proto.Unmarshal(payload, resp); err != nil {
		return nil, fmt.Errorf("unmarshal: %w", err)
	}
	models := resp.GetModels()
	caps := make(map[string]bool, len(models))
	for _, model := range models {
		name := strings.ToLower(strings.TrimSpace(model.GetName()))
		caps[name] = model.GetSupportsImages()
	}
	return caps, nil
}
// ModelSupportsImages reports whether modelName supports image input,
// according to the LS model configs visible to the given token.
//
// Results are cached per token, keyed by apiKeyHash, for cascadeModelCapsTTL.
// On a miss or after expiry, the configs are re-fetched via
// GetCascadeModelConfigs:
//   - If the fetch fails but an expired entry exists, the stale entry is
//     served and the error is logged rather than returned.
//   - If the fetch fails and there is no cached entry at all, the RPC error
//     is returned, leaving the fallback policy to the caller.
//
// Return values are (found, supportsImages, err). found is false when the
// model name, trimmed and compared case-insensitively, is not in the LS
// model list.
func (l *LocalLSClient) ModelSupportsImages(ctx context.Context, token, modelName string) (bool, bool, error) {
	name := strings.ToLower(strings.TrimSpace(modelName))
	key := apiKeyHash(token)

	l.modelCapsMu.Lock()
	if l.modelCapsCache == nil {
		l.modelCapsCache = make(map[string]cascadeModelCapsCacheEntry)
	}
	entry, cached := l.modelCapsCache[key]
	fresh := cached && time.Since(entry.FetchedAt) < cascadeModelCapsTTL
	l.modelCapsMu.Unlock()

	if fresh {
		v, found := entry.SupportsImages[name]
		return found, v, nil
	}

	// Refresh; on failure, prefer a stale entry over an error.
	caps, err := l.GetCascadeModelConfigs(ctx, token)
	if err != nil {
		if !cached {
			return false, false, err
		}
		slog.Warn("windsurf: refreshing cascade model configs failed; using stale cache",
			"error", err, "age", time.Since(entry.FetchedAt))
		v, found := entry.SupportsImages[name]
		return found, v, nil
	}

	l.modelCapsMu.Lock()
	l.modelCapsCache[key] = cascadeModelCapsCacheEntry{
		SupportsImages: caps,
		FetchedAt:      time.Now(),
	}
	l.modelCapsMu.Unlock()

	v, found := caps[name]
	return found, v, nil
}
// apiKeyHash derives a short, stable cache key from an API token: the first
// 8 bytes of its SHA-256 digest, rendered as 16 lowercase hex characters.
// 64 bits is enough to tell tokens apart while avoiding keying the cache by
// the raw token.
func apiKeyHash(token string) string {
	digest := sha256.Sum256([]byte(token))
	prefix := digest[:8]
	return fmt.Sprintf("%x", prefix)
}