- windsurf: client/pool/local_ls/tool_emulation/tool_names/models 调整 - handler: admin account_data / failover_loop / gateway_handler - repository: scheduler_cache 及测试 - service: windsurf_chat_service / windsurf_gateway_service - deploy: compose 合并为单文件(含 windsurf-ls profile),Dockerfile.ls - cmd: 新增 dump_ls_models / dump_preamble / test_windsurf_tools 辅助工具
248 lines
9.3 KiB
Go
248 lines
9.3 KiB
Go
// test_windsurf_tools validates Cascade tool-calling end-to-end.
//
// Same flow as test_windsurf_minimal but injects an OpenAI-format tools[]
// preamble into SendUserCascadeMessage and parses <tool_call> blocks back
// out of the trajectory text.
//
// Usage:
//
//	WINDSURF_JWT='devin-session-token$...' \
//	WINDSURF_CSRF_TOKEN='ad2d9f01-...' \
//	WINDSURF_USER_ID='devin-user$...' \
//	WINDSURF_TEAM_ID='devin-team$account-...' \
//	WINDSURF_LS_PORT=42099 \
//	go run ./cmd/test_windsurf_tools -verbose
package main
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"flag"
|
||
"fmt"
|
||
"os"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/windsurf"
|
||
)
|
||
|
||
// cliFlags holds every option of the tool after command-line parsing.
// parseFlags fills it; several defaults come from WINDSURF_* environment
// variables.
type cliFlags struct {
	// Credentials and identity.
	jwt       string // -jwt: session token
	csrfToken string // -csrf: LS CSRF token
	userID    string // -user-id
	teamID    string // -team-id

	// Local LS connection.
	lsPort  int           // -ls-port
	timeout time.Duration // -timeout: per-step timeout

	// Request shape.
	model      string // -model: model UID, may be empty
	prompt     string // -prompt: user prompt for the first turn
	toolChoice string // -tool-choice: auto | required | none | <tool_name>

	// Run modes.
	verbose   bool // -verbose
	roundtrip bool // -roundtrip: also run a second turn with a fake tool_result
}
|
||
|
||
func parseFlags() cliFlags {
|
||
var f cliFlags
|
||
flag.StringVar(&f.jwt, "jwt", os.Getenv("WINDSURF_JWT"), "session token")
|
||
flag.StringVar(&f.model, "model", os.Getenv("WINDSURF_MODEL"), "model UID (optional, auto-picks cheapest)")
|
||
flag.StringVar(&f.prompt, "prompt", "Find every Go file in backend/internal/pkg/windsurf whose name contains 'tool', then read the first 40 lines of tool_emulation.go. Use the tools.", "user prompt")
|
||
flag.BoolVar(&f.verbose, "verbose", false, "verbose")
|
||
flag.DurationVar(&f.timeout, "timeout", 90*time.Second, "per-step timeout")
|
||
flag.StringVar(&f.userID, "user-id", os.Getenv("WINDSURF_USER_ID"), "user id")
|
||
flag.StringVar(&f.teamID, "team-id", os.Getenv("WINDSURF_TEAM_ID"), "team id")
|
||
flag.StringVar(&f.csrfToken, "csrf", os.Getenv("WINDSURF_CSRF_TOKEN"), "LS CSRF token")
|
||
flag.IntVar(&f.lsPort, "ls-port", envInt("WINDSURF_LS_PORT", 42099), "LS port")
|
||
flag.StringVar(&f.toolChoice, "tool-choice", "auto", "auto | required | none | <tool_name>")
|
||
flag.BoolVar(&f.roundtrip, "roundtrip", false, "after first turn, inject fake tool_result and test Turn 2")
|
||
flag.Parse()
|
||
return f
|
||
}
|
||
|
||
// envInt returns the integer stored in environment variable k, or dflt when
// the variable is unset, empty, not parseable as a decimal integer, or zero.
//
// A value that is set but cannot be parsed is reported on stderr before
// falling back, so a typo in the environment does not go unnoticed.
func envInt(k string, dflt int) int {
	v := os.Getenv(k)
	if v == "" {
		return dflt
	}
	var n int
	if _, err := fmt.Sscanf(v, "%d", &n); err != nil {
		fmt.Fprintf(os.Stderr, "warning: ignoring invalid %s=%q: %v\n", k, v, err)
		return dflt
	}
	// Zero is treated as "not configured", same as an unset variable.
	if n == 0 {
		return dflt
	}
	return n
}
|
||
|
||
func main() {
|
||
f := parseFlags()
|
||
if f.jwt == "" || f.csrfToken == "" || f.userID == "" || f.teamID == "" {
|
||
fmt.Fprintln(os.Stderr, "missing WINDSURF_JWT / CSRF / USER_ID / TEAM_ID")
|
||
os.Exit(2)
|
||
}
|
||
|
||
// Build tools[] — realistic coding tools: read_file, find_file, grep, list_dir
|
||
tools := []windsurf.OpenAITool{
|
||
{Type: "function", Function: windsurf.OpenAIFunction{
|
||
Name: "read_file",
|
||
Description: "Read the contents of a file. Use when you need to see what's inside a specific file.",
|
||
Parameters: json.RawMessage(`{"type":"object","properties":{
|
||
"path":{"type":"string","description":"Absolute or repo-relative file path"},
|
||
"start_line":{"type":"integer","description":"Optional 1-indexed start line","minimum":1},
|
||
"end_line":{"type":"integer","description":"Optional 1-indexed inclusive end line","minimum":1}
|
||
},"required":["path"]}`),
|
||
}},
|
||
{Type: "function", Function: windsurf.OpenAIFunction{
|
||
Name: "find_file",
|
||
Description: "Find files by glob pattern. Use when looking for files whose path matches a pattern.",
|
||
Parameters: json.RawMessage(`{"type":"object","properties":{
|
||
"pattern":{"type":"string","description":"Glob pattern, e.g. **/*.go or src/**/test_*.py"},
|
||
"max_results":{"type":"integer","default":50}
|
||
},"required":["pattern"]}`),
|
||
}},
|
||
{Type: "function", Function: windsurf.OpenAIFunction{
|
||
Name: "grep",
|
||
Description: "Search file contents by regex. Use when looking for code that matches a text pattern.",
|
||
Parameters: json.RawMessage(`{"type":"object","properties":{
|
||
"regex":{"type":"string","description":"POSIX/PCRE regex"},
|
||
"path_glob":{"type":"string","description":"Optional path glob filter, e.g. **/*.ts"},
|
||
"case_insensitive":{"type":"boolean","default":false}
|
||
},"required":["regex"]}`),
|
||
}},
|
||
{Type: "function", Function: windsurf.OpenAIFunction{
|
||
Name: "list_dir",
|
||
Description: "List files and sub-directories at a path. Use for shallow directory exploration.",
|
||
Parameters: json.RawMessage(`{"type":"object","properties":{
|
||
"path":{"type":"string","description":"Directory path"}
|
||
},"required":["path"]}`),
|
||
}},
|
||
}
|
||
// Resolve tool_choice: "auto" | "required" | "none" | tool_name → object
|
||
var toolChoice interface{} = f.toolChoice
|
||
if f.toolChoice != "auto" && f.toolChoice != "required" && f.toolChoice != "none" {
|
||
toolChoice = map[string]any{"type": "function", "function": map[string]any{"name": f.toolChoice}}
|
||
}
|
||
preamble := windsurf.BuildToolPreambleForProto(tools, toolChoice)
|
||
if preamble == "" {
|
||
fmt.Fprintln(os.Stderr, "empty preamble"); os.Exit(1)
|
||
}
|
||
if f.verbose {
|
||
fmt.Printf("── Preamble (%d bytes) head 200 chars ──\n%s…\n\n",
|
||
len(preamble), truncate(preamble, 200))
|
||
}
|
||
|
||
// LS client — note: user_id/team_id are not used by LS client directly,
|
||
// only by the remote account status APIs. Warmup sends a JWT only.
|
||
lsClient := windsurf.NewLocalLSClient(f.lsPort, f.csrfToken)
|
||
_ = f.userID
|
||
_ = f.teamID
|
||
|
||
// Pick model: use given or default to Claude 4.5 Haiku (cheapest Claude)
|
||
pickedModel := f.model
|
||
if pickedModel == "" {
|
||
pickedModel = "MODEL_PRIVATE_11" // claude-4.5-haiku
|
||
}
|
||
|
||
// Warmup
|
||
ctx, cancel := context.WithTimeout(context.Background(), f.timeout)
|
||
defer cancel()
|
||
if err := lsClient.WarmupCascade(ctx, f.jwt); err != nil {
|
||
fmt.Fprintln(os.Stderr, "WarmupCascade:", err); os.Exit(1)
|
||
}
|
||
fmt.Println("✅ WarmupCascade")
|
||
|
||
// StartCascade
|
||
cascadeID, err := lsClient.StartCascade(ctx, f.jwt)
|
||
if err != nil {
|
||
fmt.Fprintln(os.Stderr, "StartCascade:", err); os.Exit(1)
|
||
}
|
||
fmt.Printf("✅ StartCascade cascade_id=%s\n", cascadeID)
|
||
|
||
// Call StreamCascadeChat (full flow incl. trajectory polling)
|
||
res, err := lsClient.StreamCascadeChat(ctx, f.jwt, pickedModel, f.prompt, preamble, cascadeID, 0)
|
||
if err != nil {
|
||
fmt.Fprintln(os.Stderr, "StreamCascadeChat:", err); os.Exit(1)
|
||
}
|
||
fmt.Printf("✅ StreamCascadeChat text_len=%d thinking_len=%d native_tool_calls=%d\n",
|
||
len(res.Text), len(res.Thinking), len(res.ToolCalls))
|
||
|
||
fmt.Println("\n── Raw Text ──")
|
||
fmt.Println(res.Text)
|
||
if res.Thinking != "" && f.verbose {
|
||
fmt.Println("\n── Thinking ──")
|
||
fmt.Println(res.Thinking)
|
||
}
|
||
|
||
// Parse tool calls from text
|
||
parsed := windsurf.ParseToolCallsFromText(res.Text)
|
||
fmt.Printf("\n── Parsed tool_calls: %d ──\n", len(parsed.ToolCalls))
|
||
for i, tc := range parsed.ToolCalls {
|
||
fmt.Printf("[%d] id=%s name=%s args=%s\n", i, tc.ID, tc.Name, tc.ArgumentsJSON)
|
||
}
|
||
fmt.Printf("\n── Text after stripping tool_call: ──\n%s\n", parsed.Text)
|
||
|
||
if len(parsed.ToolCalls) == 0 && len(res.ToolCalls) == 0 {
|
||
fmt.Fprintln(os.Stderr, "\n❌ NO TOOL CALLS produced")
|
||
os.Exit(1)
|
||
}
|
||
fmt.Println("\n✅ tool-calling E2E works")
|
||
|
||
// ───── Turn 2: inject fake tool_result and see if model continues ─────
|
||
if f.roundtrip && len(parsed.ToolCalls) > 0 {
|
||
tc := parsed.ToolCalls[0]
|
||
// Snapshot step count after Turn 1
|
||
ctxSnap, cancelSnap := context.WithTimeout(context.Background(), 10*time.Second)
|
||
stepsT1, _ := lsClient.GetTrajectorySteps(ctxSnap, cascadeID, 0)
|
||
cancelSnap()
|
||
fmt.Printf("\n── After Turn 1: trajectory has %d steps ──\n", len(stepsT1))
|
||
for i, s := range stepsT1 {
|
||
txt := s.ResponseText
|
||
if len(txt) > 80 { txt = txt[:80] + "..." }
|
||
fmt.Printf(" step[%d] type=%d text=%q\n", i, s.Type, txt)
|
||
}
|
||
|
||
fakeResult := `["cmd/server/main.go","cmd/test_windsurf_tools/main.go","internal/pkg/windsurf/tool_emulation.go"]`
|
||
turn2 := fmt.Sprintf(
|
||
`<tool_result tool_call_id="%s">%s</tool_result>`+"\n\nBased on the tool result above, tell me which files look test-related.",
|
||
tc.ID, fakeResult)
|
||
ctx2, cancel2 := context.WithTimeout(context.Background(), f.timeout)
|
||
defer cancel2()
|
||
res2, err := lsClient.StreamCascadeChat(ctx2, f.jwt, pickedModel, turn2, preamble, cascadeID, 0)
|
||
if err != nil {
|
||
fmt.Fprintln(os.Stderr, "\n❌ Turn2 StreamCascadeChat:", err)
|
||
os.Exit(1)
|
||
}
|
||
fmt.Printf("\n── Turn 2 response (text_len=%d thinking_len=%d) ──\n%s\n",
|
||
len(res2.Text), len(res2.Thinking), res2.Text)
|
||
parsed2 := windsurf.ParseToolCallsFromText(res2.Text)
|
||
fmt.Printf("\n── Turn 2 parsed tool_calls: %d ──\n", len(parsed2.ToolCalls))
|
||
for i, tc := range parsed2.ToolCalls {
|
||
fmt.Printf("[%d] id=%s name=%s args=%s\n", i, tc.ID, tc.Name, tc.ArgumentsJSON)
|
||
}
|
||
if len(parsed2.Text) > 20 && !containsIgnore(res2.Text, "i don't have access") {
|
||
fmt.Println("\n✅ round-trip works: model consumed tool_result and produced text")
|
||
} else {
|
||
fmt.Println("\n⚠️ round-trip suspicious: short or refusal text")
|
||
}
|
||
// Snapshot after Turn 2
|
||
ctxSnap2, cancelSnap2 := context.WithTimeout(context.Background(), 10*time.Second)
|
||
stepsT2, _ := lsClient.GetTrajectorySteps(ctxSnap2, cascadeID, 0)
|
||
cancelSnap2()
|
||
fmt.Printf("\n── After Turn 2: trajectory has %d steps (was %d after Turn 1) ──\n", len(stepsT2), len(stepsT1))
|
||
for i, s := range stepsT2 {
|
||
txt := s.ResponseText
|
||
if len(txt) > 80 { txt = txt[:80] + "..." }
|
||
fmt.Printf(" step[%d] type=%d text=%q\n", i, s.Type, txt)
|
||
}
|
||
}
|
||
}
|
||
|
||
// containsIgnore reports whether needle occurs in haystack when both are
// compared in lower case. An empty needle always matches.
func containsIgnore(haystack, needle string) bool {
	loweredHaystack := strings.ToLower(haystack)
	loweredNeedle := strings.ToLower(needle)
	return strings.Contains(loweredHaystack, loweredNeedle)
}
|
||
|
||
// truncate shortens s to at most n bytes for display. A string that already
// fits is returned unchanged; otherwise the longest prefix of at most n bytes
// that ends on a rune boundary is returned with "..." appended.
//
// Cutting on a rune boundary keeps multi-byte UTF-8 characters intact instead
// of emitting a broken trailing byte sequence. A negative n is treated as 0.
func truncate(s string, n int) string {
	if len(s) <= n {
		return s
	}
	// Ranging over a string yields the byte offset of each rune start, so the
	// last offset not exceeding n is the safe cut point.
	cut := 0
	for i := range s {
		if i > n {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
|
||
|
||
// Blank reference that keeps the strings import in use. containsIgnore
// already calls into strings, so this only matters if that call is removed.
var _ = strings.HasPrefix
|