502 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// test_windsurf_minimal validates the Windsurf Cascade chat flow end-to-end:
//
//  1. JWT decode (local)
//  2. GetUserStatus (resolve user_id/team_id; only called when one of them is still unknown)
//  3. CheckChatCapacity
//  4. GetCascadeModelConfigs (pick cheapest non-BYOK model)
//  5. CascadeChat via local LS:
//     a. WarmupCascade (InitializeCascadePanelState + AddTrackedWorkspace + UpdateWorkspaceTrust)
//     b. StartCascade → cascade_id
//     c. SendUserCascadeMessage
//     d. Poll GetCascadeTrajectorySteps until IDLE
//  6. Completeness check (non-empty text)
//
// Usage:
//
//	WINDSURF_JWT="devin-session-token$xxx.yyy.zzz" \
//	WINDSURF_CSRF_TOKEN="..." \
//	go run ./cmd/test_windsurf_minimal -verbose
package main
import (
	"context"
	"flag"
	"fmt"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/windsurf"
)
// cliFlags collects every command-line option of the tool. Several options
// take their default from an environment variable; see parseFlags for the
// exact flag names and defaults.
type cliFlags struct {
	// jwt is the full session token passed to every upstream call.
	jwt string
	// baseURL is the upstream base URL.
	baseURL string
	// model is the modelUid to chat with; empty means "pick the cheapest".
	model string
	// prompt is the user message sent to the cascade.
	prompt string
	// proxy is an optional HTTP proxy URL.
	proxy string
	// verbose enables extra diagnostic output on stderr.
	verbose bool
	// timeout bounds each individual step.
	timeout time.Duration
	// userID and teamID override the values resolved from the JWT or from
	// GetUserStatus.
	userID string
	teamID string
	// csrfToken is the x-codeium-csrf-token header value; when empty, main
	// tries to auto-detect it.
	csrfToken string
	// lsPort is the local language-server port; 0 means auto-detect.
	lsPort int
}
// parseFlags registers all options on the default flag set, parses the
// process arguments, and returns the resulting values. Environment variables
// are read once, at registration time, to provide the defaults.
func parseFlags() cliFlags {
	var opts cliFlags

	// All string-valued options share the same registration call, so they are
	// described in a table.
	stringOpts := []struct {
		dst   *string
		name  string
		def   string
		usage string
	}{
		{&opts.jwt, "jwt", os.Getenv("WINDSURF_JWT"),
			"full session token (e.g. devin-session-token$eyJ...). Defaults to $WINDSURF_JWT"},
		{&opts.baseURL, "base-url", envOr("WINDSURF_BASE_URL", windsurf.DefaultBaseURL),
			"upstream base URL"},
		{&opts.model, "model", "",
			"modelUid to use (e.g. claude-opus-4-7-medium); empty = pick cheapest from ListModels"},
		{&opts.prompt, "prompt", "Say hello in 3 words.",
			"user prompt"},
		{&opts.proxy, "proxy", os.Getenv("HTTPS_PROXY"),
			"optional HTTP proxy URL (mitm capture)"},
		{&opts.userID, "user-id", os.Getenv("WINDSURF_USER_ID"),
			"metadata F20 user-XXX (from userStatus proto)"},
		{&opts.teamID, "team-id", os.Getenv("WINDSURF_TEAM_ID"),
			"metadata F32 devin-team$account-XXX (from userStatus proto)"},
		{&opts.csrfToken, "csrf-token", os.Getenv("WINDSURF_CSRF_TOKEN"),
			"x-codeium-csrf-token header value (WINDSURF_CSRF_TOKEN env or from LS process args)"},
	}
	for _, so := range stringOpts {
		flag.StringVar(so.dst, so.name, so.def, so.usage)
	}

	flag.BoolVar(&opts.verbose, "verbose", false, "print extra dump info")
	flag.DurationVar(&opts.timeout, "timeout", 90*time.Second, "per-step timeout")
	flag.IntVar(&opts.lsPort, "ls-port", envInt("WINDSURF_LS_PORT", 0),
		"local LanguageServerService gRPC port (0 = auto-detect)")

	flag.Parse()
	return opts
}
// envOr returns the value of environment variable k, falling back to def
// when the variable is unset or set to the empty string.
func envOr(k, def string) string {
	val, present := os.LookupEnv(k)
	if !present || val == "" {
		return def
	}
	return val
}
// envInt returns environment variable k parsed as a base-10 integer.
//
// def is returned when the variable is unset, empty, or not a clean integer.
// Surrounding whitespace is tolerated, but trailing garbage is not: a value
// such as "8080abc" yields def rather than being silently read as 8080.
func envInt(k string, def int) int {
	raw := strings.TrimSpace(os.Getenv(k))
	if raw == "" {
		return def
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		return def
	}
	return n
}
// stepResult is one row of the final report. main builds it with positional
// literals, so the field order below must not change.
type stepResult struct {
	// name is the label printed for the step.
	name string
	// ok records whether the step passed.
	ok bool
	// detail is free-form text: a summary on success, the error on failure.
	detail string
	// elapsed is the duration reported next to the step.
	elapsed time.Duration
}
// main runs the end-to-end Windsurf Cascade check and prints a per-step
// report.
//
// Steps run strictly in order; the first failing step prints the report so
// far and exits with status 1. Missing or invalid setup (no JWT, client
// construction failure) exits with status 2. When every step passes the
// model's reply is printed after the report.
func main() {
	f := parseFlags()
	if strings.TrimSpace(f.jwt) == "" {
		fmt.Fprintln(os.Stderr, "ERROR: -jwt or WINDSURF_JWT required (full token incl. devin-session-token$ prefix)")
		flag.Usage()
		os.Exit(2)
	}

	// Auto-detect the CSRF token if not provided. This must happen before any
	// client is constructed so that the upstream client and the local LS
	// client are both given the same (possibly auto-detected) token.
	if f.csrfToken == "" {
		f.csrfToken = detectLSCSRF()
		if f.verbose && f.csrfToken != "" {
			// elide never slices past the end, so a short token cannot panic.
			fmt.Fprintf(os.Stderr, " auto-detected CSRF token: %s\n", elide(f.csrfToken, 8))
		}
	}

	client, err := windsurf.NewClient(f.baseURL, f.proxy, f.csrfToken)
	if err != nil {
		fmt.Fprintln(os.Stderr, "ERROR build client:", err)
		os.Exit(2)
	}

	results := make([]stepResult, 0, 8)
	pickedModel := f.model
	userID := f.userID
	teamID := f.teamID

	// ── Step 1: JWT decode ────────────────────────────────────────────────
	{
		t0 := time.Now()
		claims, err := windsurf.DecodeJWTClaims(f.jwt)
		el := time.Since(t0)
		if err != nil {
			results = append(results, stepResult{"JWT 解码", false, err.Error(), el})
			printResults(results)
			os.Exit(1)
		}
		now := time.Now().Unix()
		expStr := "(no exp)"
		expired := false
		// A zero/negative Exp is treated as "no expiry claim", not as expired.
		if claims.Exp > 0 {
			expStr = time.Unix(claims.Exp, 0).Format(time.RFC3339)
			if claims.Exp <= now {
				expired = true
			}
		}
		// Explicit flags win; the JWT claims only fill in what is missing.
		if userID == "" {
			userID = claims.UserID
		}
		if teamID == "" {
			teamID = claims.TeamID
		}
		detail := fmt.Sprintf("session_id=%s user_id=%s team_id=%s exp=%s",
			elide(claims.SessionID, 20), claims.UserID, claims.TeamID, expStr)
		if expired {
			results = append(results, stepResult{"JWT 解码", false, detail + " (EXPIRED)", el})
			printResults(results)
			os.Exit(1)
		}
		results = append(results, stepResult{"JWT 解码", true, detail, el})
	}

	// ── Step 2: GetUserStatus ─────────────────────────────────────────────
	// Only needed when neither the flags nor the JWT supplied both IDs.
	if userID == "" || teamID == "" {
		t0 := time.Now()
		ctx, cancel := context.WithTimeout(context.Background(), f.timeout)
		us, err := client.GetUserStatus(ctx, f.jwt)
		cancel()
		el := time.Since(t0)
		if err != nil {
			results = append(results, stepResult{"GetUserStatus", false, err.Error(), el})
			printResults(results)
			os.Exit(1)
		}
		if userID == "" {
			userID = us.UserID
		}
		if teamID == "" {
			teamID = us.TeamID
		}
		detail := fmt.Sprintf("user_id=%s team_id=%s", elide(userID, 30), elide(teamID, 40))
		results = append(results, stepResult{"GetUserStatus", true, detail, el})
	}

	// ── Step 3: CheckChatCapacity ─────────────────────────────────────────
	{
		t0 := time.Now()
		ctx, cancel := context.WithTimeout(context.Background(), f.timeout)
		hasCap, raw, err := client.CheckChatCapacity(ctx, f.jwt)
		cancel()
		el := time.Since(t0)
		if err != nil {
			results = append(results, stepResult{"CheckChatCapacity", false, err.Error(), el})
			printResults(results)
			os.Exit(1)
		}
		detail := fmt.Sprintf("hasCapacity=%v raw=%s", hasCap, raw)
		// The step itself passes only when capacity is available.
		results = append(results, stepResult{"CheckChatCapacity", hasCap, detail, el})
		if !hasCap {
			printResults(results)
			os.Exit(1)
		}
	}

	// ── Step 4: List models ───────────────────────────────────────────────
	{
		t0 := time.Now()
		ctx, cancel := context.WithTimeout(context.Background(), f.timeout)
		models, err := client.ListModels(ctx, f.jwt)
		cancel()
		el := time.Since(t0)
		if err != nil {
			results = append(results, stepResult{"GetCascadeModelConfigs", false, err.Error(), el})
			printResults(results)
			os.Exit(1)
		}
		if len(models) == 0 {
			results = append(results, stepResult{"GetCascadeModelConfigs", false, "no models returned", el})
			printResults(results)
			os.Exit(1)
		}
		// No -model given: choose automatically. Otherwise the requested model
		// must exist in the catalog.
		if pickedModel == "" {
			pickedModel = pickCheapest(models)
		} else if !windsurf.HasModel(models, pickedModel) {
			results = append(results, stepResult{"GetCascadeModelConfigs", false,
				fmt.Sprintf("requested model %q not in catalog", pickedModel), el})
			printResults(results)
			os.Exit(1)
		}
		detail := fmt.Sprintf("got %d models, picked: %s", len(models), pickedModel)
		if f.verbose {
			detail += "\n Top 5 by multiplier:"
			for i, m := range topNCheapest(models, 5) {
				detail += fmt.Sprintf("\n [%d] %-40s ×%-5g %s", i+1, m.ModelUID, m.CreditMultiplier, m.Label)
			}
		}
		results = append(results, stepResult{"GetCascadeModelConfigs", true, detail, el})
	}

	// ── Step 5: Cascade chat via local LS ────────────────────────────────
	finalText := ""
	{
		// t0 is shared by all sub-steps below, so their reported durations are
		// cumulative from the start of step 5.
		t0 := time.Now()
		lsPort := f.lsPort
		if lsPort == 0 {
			lsPort = detectLSPort()
		}
		if lsPort == 0 {
			results = append(results, stepResult{"CascadeChat", false,
				"no local LS port found; set WINDSURF_LS_PORT or -ls-port", time.Since(t0)})
			printResults(results)
			os.Exit(1)
		}
		lsClient := windsurf.NewLocalLSClient(lsPort, f.csrfToken)

		// Warmup
		{
			ctx, cancel := context.WithTimeout(context.Background(), f.timeout)
			// Warmup is best-effort: its result is deliberately ignored and the
			// step is always recorded as passed.
			// NOTE(review): a warmup failure is therefore invisible in the report — confirm this is intended.
			_ = lsClient.WarmupCascade(ctx, f.jwt)
			cancel()
			// Show at most the first 8 bytes of the session ID without
			// panicking on a shorter one.
			sid := lsClient.SessionID
			if len(sid) > 8 {
				sid = sid[:8]
			}
			results = append(results, stepResult{"WarmupCascade", true,
				fmt.Sprintf("ls_port=%d session=%s", lsPort, sid), time.Since(t0)})
		}

		// StartCascade
		var cascadeID string
		{
			ctx, cancel := context.WithTimeout(context.Background(), f.timeout)
			cid, err := lsClient.StartCascade(ctx, f.jwt)
			cancel()
			if err != nil {
				results = append(results, stepResult{"StartCascade", false, err.Error(), time.Since(t0)})
				printResults(results)
				os.Exit(1)
			}
			cascadeID = cid
			results = append(results, stepResult{"StartCascade", true,
				fmt.Sprintf("cascade_id=%s", cid), time.Since(t0)})
		}

		// SendUserCascadeMessage
		{
			ctx, cancel := context.WithTimeout(context.Background(), f.timeout)
			newCID, err := lsClient.SendUserCascadeMessage(ctx, f.jwt, cascadeID, f.prompt, pickedModel, "", 0, nil, true)
			// The call may hand back a different cascade ID; follow it if so.
			if err == nil && newCID != "" {
				cascadeID = newCID
			}
			cancel()
			if err != nil {
				results = append(results, stepResult{"SendCascadeMsg", false, err.Error(), time.Since(t0)})
				printResults(results)
				os.Exit(1)
			}
			results = append(results, stepResult{"SendCascadeMsg", true,
				fmt.Sprintf("model=%s prompt_len=%d", pickedModel, len(f.prompt)), time.Since(t0)})
		}

		// Poll trajectory steps until IDLE
		t0Chat := time.Now()
		ttft := time.Duration(0) // time until the first step with text was seen
		firstText := true
		seenSteps := 0 // steps below this index have already been consumed
		deadline := time.Now().Add(f.timeout)
		sawActive := false
		// During the grace window an idle status is ignored unless the
		// trajectory has already been seen in a non-idle state.
		graceEnd := time.Now().Add(8 * time.Second)
		idleCount := 0
		for time.Now().Before(deadline) {
			time.Sleep(500 * time.Millisecond)
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			steps, err := lsClient.GetTrajectorySteps(ctx, cascadeID, 0)
			cancel()
			if err != nil {
				// Transient polling errors are retried until the deadline.
				if f.verbose {
					fmt.Fprintf(os.Stderr, " GetTrajectorySteps err: %v\n", err)
				}
				continue
			}
			for idx, s := range steps {
				// Steps without text are skipped and stay "unseen" so they are
				// picked up on a later poll once they carry text.
				if s.Text == "" {
					continue
				}
				if idx >= seenSteps {
					if firstText {
						ttft = time.Since(t0Chat)
						firstText = false
					}
					if s.Type == 17 { // error step
						if f.verbose {
							fmt.Fprintf(os.Stderr, " error step[%d]: %s\n", idx, elide(s.Text, 100))
						}
						// NOTE(review): a rate-limit error replaces the collected text with a
						// non-empty marker, so the run still counts as having produced text.
						if strings.Contains(s.Text, "rate limit") {
							finalText = "(rate-limited: " + elide(s.Text, 80) + ")"
						}
					} else {
						finalText += s.Text
						if f.verbose {
							fmt.Fprintf(os.Stderr, " step[%d] type=%d status=%d text=%q\n",
								idx, s.Type, s.Status, elide(s.Text, 60))
						}
					}
					seenSteps = idx + 1
				}
			}
			ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
			status, err := lsClient.GetTrajectoryStatus(ctx2, cascadeID)
			cancel2()
			if f.verbose {
				fmt.Fprintf(os.Stderr, " trajectory status=%d err=%v steps_so_far=%d\n", status, err, seenSteps)
			}
			if err != nil {
				continue
			}
			if status != 0 && status != 1 && status != 2 {
				sawActive = true
			}
			if status == 1 || status == 2 { // IDLE
				if !sawActive && time.Now().Before(graceEnd) {
					continue
				}
				idleCount++
				// Require consecutive idle polls before stopping: two when text
				// has arrived, four when nothing has been received yet.
				if (finalText != "" && idleCount >= 2) || (finalText == "" && idleCount >= 4) {
					break
				}
			} else {
				sawActive = true
				idleCount = 0
			}
		}
		el := time.Since(t0)
		detail := fmt.Sprintf("steps=%d TTFT=%v text_len=%d", seenSteps, ttft.Round(time.Millisecond), len(finalText))
		results = append(results, stepResult{"CascadeChat 轨迹", finalText != "", detail, el})
	}

	// ── Step 6: Completeness ──────────────────────────────────────────────
	{
		t0 := time.Now()
		var problems []string
		if strings.TrimSpace(finalText) == "" {
			problems = append(problems, "empty text")
		}
		ok := len(problems) == 0
		detail := "all checks passed"
		if !ok {
			detail = strings.Join(problems, ", ")
		}
		results = append(results, stepResult{"完整性校验", ok, detail, time.Since(t0)})
	}

	printResults(results)
	if finalText != "" {
		fmt.Println()
		fmt.Println("─── 模型回复 ───")
		fmt.Println(finalText)
	}
	if !allPassed(results) {
		os.Exit(1)
	}
}
// printResults writes the step report to stdout: one numbered line per step
// (name, pass/fail mark, elapsed time rounded to milliseconds, detail),
// followed by an overall pass/fail summary line.
func printResults(rs []stepResult) {
	fmt.Println()
	total := len(rs)
	for idx := range rs {
		step := rs[idx]
		mark := "❌"
		if step.ok {
			mark = "✅"
		}
		fmt.Printf("[%d/%d] %-26s %s %-7s %s\n", idx+1, total, step.name, mark, step.elapsed.Round(time.Millisecond), step.detail)
	}
	fmt.Println()

	summary := "❌ 有步骤失败"
	if allPassed(rs) {
		summary = "✅ 全部通过"
	}
	fmt.Println(summary)
}
// allPassed reports whether rs is non-empty and every step in it succeeded.
// An empty report counts as a failure.
func allPassed(rs []stepResult) bool {
	passed := len(rs) > 0
	for i := 0; passed && i < len(rs); i++ {
		passed = rs[i].ok
	}
	return passed
}
// elide shortens s for display. Strings of at most n bytes are returned
// unchanged; longer ones are cut to at most n bytes and suffixed with "...".
//
// The cut always falls on a rune boundary, so a multi-byte UTF-8 character is
// never split in half (the result may then be slightly shorter than n bytes).
// A negative n is treated as 0.
func elide(s string, n int) string {
	if n < 0 {
		n = 0
	}
	if len(s) <= n {
		return s
	}
	// Ranging over a string yields the byte offset at which each rune starts;
	// keep the largest such offset that does not exceed n.
	cut := 0
	for i := range s {
		if i > n {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
// pickCheapest returns the ModelUID of the non-BYOK model with the smallest
// positive CreditMultiplier. Ties go to the model listed first.
//
// Every entry, including the first one, goes through the same filters: a model
// whose UID contains "byok" (case-insensitive) or whose multiplier is not
// positive is never preferred over a priced non-BYOK model. Fallbacks, in
// order: the first non-BYOK model when none has a positive multiplier, then
// the first model when all are BYOK. An empty catalog yields "".
func pickCheapest(models []windsurf.ModelInfo) string {
	if len(models) == 0 {
		return ""
	}
	bestIdx := -1     // cheapest priced, non-BYOK model seen so far
	fallbackIdx := -1 // first non-BYOK model, priced or not
	for i := range models {
		if strings.Contains(strings.ToLower(models[i].ModelUID), "byok") {
			continue
		}
		if fallbackIdx < 0 {
			fallbackIdx = i
		}
		if models[i].CreditMultiplier <= 0 {
			continue
		}
		if bestIdx < 0 || models[i].CreditMultiplier < models[bestIdx].CreditMultiplier {
			bestIdx = i
		}
	}
	switch {
	case bestIdx >= 0:
		return models[bestIdx].ModelUID
	case fallbackIdx >= 0:
		return models[fallbackIdx].ModelUID
	default:
		// Only BYOK models exist; still return something usable.
		return models[0].ModelUID
	}
}
// topNCheapest returns up to n non-BYOK models ordered by ascending
// CreditMultiplier. The input slice is not modified.
//
// Models with a positive multiplier come first, cheapest first. Models whose
// multiplier is not positive are listed after all priced ones. Within equal
// rank the catalog order is preserved. A negative n is treated as 0.
func topNCheapest(models []windsurf.ModelInfo, n int) []windsurf.ModelInfo {
	cp := make([]windsurf.ModelInfo, 0, len(models))
	for _, m := range models {
		if strings.Contains(strings.ToLower(m.ModelUID), "byok") {
			continue
		}
		cp = append(cp, m)
	}

	limit := n
	if limit < 0 {
		limit = 0
	}
	if limit > len(cp) {
		limit = len(cp)
	}

	// cheaper reports whether cp[a] must be listed strictly before cp[b].
	cheaper := func(a, b int) bool {
		aPriced := cp[a].CreditMultiplier > 0
		bPriced := cp[b].CreditMultiplier > 0
		if aPriced != bPriced {
			return aPriced
		}
		return aPriced && cp[a].CreditMultiplier < cp[b].CreditMultiplier
	}

	// Partial selection: only the first `limit` positions need to be ordered.
	for i := 0; i < limit; i++ {
		minIdx := i
		for j := i + 1; j < len(cp); j++ {
			if cheaper(j, minIdx) {
				minIdx = j
			}
		}
		// Move the selected entry to position i by shifting the entries in
		// between one slot to the right, which keeps ties in catalog order.
		picked := cp[minIdx]
		copy(cp[i+1:minIdx+1], cp[i:minIdx])
		cp[i] = picked
	}
	return cp[:limit]
}
// detectLSPort looks for a listening port of a running Windsurf language
// server by shelling out to pgrep, lsof, awk and grep, and returns the first
// port the pipeline prints. It returns 0 when the command fails, prints
// nothing, or prints something that does not parse as an integer.
func detectLSPort() int {
	const pipeline = `pgrep -f 'Windsurf.app.*language_server' 2>/dev/null | xargs -I{} lsof -p {} 2>/dev/null | awk '/LISTEN/{print $9}' | grep -oE '[0-9]+$' | head -1`

	raw, runErr := exec.Command("sh", "-c", pipeline).Output()
	if runErr != nil || len(raw) == 0 {
		return 0
	}

	port := 0
	if _, scanErr := fmt.Sscanf(strings.TrimSpace(string(raw)), "%d", &port); scanErr != nil {
		return 0
	}
	return port
}
// detectLSCSRF tries to read WINDSURF_CSRF_TOKEN from the environment of a
// running Windsurf language-server process (via /proc/<pid>/environ, falling
// back to `ps eww`). It returns the first 36-character hex-and-dash token the
// pipeline prints, or "" when the command fails or finds nothing.
func detectLSCSRF() string {
	const pipeline = `pgrep -f 'Windsurf.app.*language_server' 2>/dev/null | while read pid; do grep -z WINDSURF_CSRF_TOKEN /proc/$pid/environ 2>/dev/null || ps eww -p $pid 2>/dev/null | tr ' ' '\n' | grep WINDSURF_CSRF_TOKEN; done | grep -oE '[0-9a-f-]{36}' | head -1`

	raw, runErr := exec.Command("sh", "-c", pipeline).Output()
	if runErr == nil && len(raw) > 0 {
		return strings.TrimSpace(string(raw))
	}
	return ""
}