win 002066e700 chore(wip): 保存订制改动以便合并上游
- windsurf: client/pool/local_ls/tool_emulation/tool_names/models 调整
- handler: admin account_data / failover_loop / gateway_handler
- repository: scheduler_cache 及测试
- service: windsurf_chat_service / windsurf_gateway_service
- deploy: compose 合并为单文件(含 windsurf-ls profile),Dockerfile.ls
- cmd: 新增 dump_ls_models / dump_preamble / test_windsurf_tools 辅助工具
2026-04-24 11:14:36 +08:00

279 lines
8.2 KiB
Go

// HTTP client for Windsurf upstream JSON/Connect-RPC endpoints.
// Portions derived from windsurf-tools (MIT 2025 shaoyu521). See ./LICENSE.
package windsurf
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// Client wraps an *http.Client and the Windsurf base URL.
type Client struct {
	// BaseURL is the prefix unaryJSON joins with an RPC path such as
	// "/exa.api_server_pb.ApiServerService/GetCascadeModelConfigs".
	// GetUserStatus does not use it (it targets server.codeium.com directly).
	BaseURL string
	// HTTP sends every request; NewClient builds it with a 180s overall
	// timeout and an optional proxy.
	HTTP *http.Client
	// CSRFToken holds the optional token passed to NewClient.
	// NOTE(review): no request built in this file reads it — confirm where it
	// is consumed.
	CSRFToken string
}
// NewClient builds a Client that talks to baseURL (DefaultBaseURL when empty).
//
// proxyURL may be empty (or whitespace only), in which case requests are sent
// directly: the transport's Proxy is left nil, so environment proxy variables
// are not consulted. When set, it must be an absolute URL carrying both a
// scheme and a host, e.g. "http://127.0.0.1:7890". Strings that url.Parse
// accepts but that cannot work as a proxy are rejected here, rather than
// surfacing later as an opaque failure on every request.
//
// csrfToken is optional; only its first element is used.
func NewClient(baseURL, proxyURL string, csrfToken ...string) (*Client, error) {
	if baseURL == "" {
		baseURL = DefaultBaseURL
	}
	transport := &http.Transport{
		ForceAttemptHTTP2:     true,
		IdleConnTimeout:       90 * time.Second,
		ResponseHeaderTimeout: 60 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	if p := strings.TrimSpace(proxyURL); p != "" {
		u, err := url.Parse(p)
		if err != nil {
			return nil, fmt.Errorf("parse proxy: %w", err)
		}
		// url.Parse is permissive: "localhost:7890" parses as scheme
		// "localhost" with an empty host, and a bare word parses as a path.
		// Neither is a usable proxy, so fail fast.
		if u.Scheme == "" || u.Host == "" {
			return nil, fmt.Errorf("parse proxy: %q: missing scheme or host", p)
		}
		transport.Proxy = http.ProxyURL(u)
	}
	var csrf string
	if len(csrfToken) > 0 {
		csrf = csrfToken[0]
	}
	return &Client{
		BaseURL:   baseURL,
		CSRFToken: csrf,
		HTTP: &http.Client{
			Transport: transport,
			Timeout:   180 * time.Second,
		},
	}, nil
}
// CheckChatCapacity calls the ApiServerService CheckChatCapacity RPC for the
// account identified by token.
//
// It returns the decoded hasCapacity flag together with the raw response body
// as a string. If the request itself fails, the raw string is empty. If only
// JSON decoding fails, the raw body is still returned alongside the error.
func (c *Client) CheckChatCapacity(ctx context.Context, token string) (bool, string, error) {
	jwt := StripDevinPrefix(token)
	meta := map[string]any{
		"apiKey":           token,
		"ideName":          AppName,
		"ideVersion":       AppVersion,
		"extensionName":    AppName,
		"extensionVersion": "0.2.0",
		"sessionId":        generateUUID(),
		"requestId":        randomUint64String(),
	}
	payload := map[string]any{"metadata": meta}

	raw, err := c.unaryJSON(ctx, "/exa.api_server_pb.ApiServerService/CheckChatCapacity", payload, jwt)
	if err != nil {
		return false, "", err
	}

	var reply struct {
		HasCapacity bool `json:"hasCapacity"`
	}
	text := string(raw)
	if decodeErr := json.Unmarshal(raw, &reply); decodeErr != nil {
		return false, text, fmt.Errorf("decode: %w", decodeErr)
	}
	return reply.HasCapacity, text, nil
}
// UserStatus holds the fields from GetUserStatus.
//
// As populated by GetUserStatus, pointer fields are nil when the upstream
// response carried no value for them, and the four credit fields have already
// been divided by 100.
type UserStatus struct {
	UserID string `json:"userId"`
	TeamID string `json:"teamId"`
	Name   string `json:"name"`
	Email  string `json:"email"`
	// PlanName is upstream planInfo.planName rendered as a string (the server
	// may send it as a string or as a number).
	PlanName string `json:"planName,omitempty"`
	// DailyPercent / WeeklyPercent are copied unchanged from upstream
	// dailyQuotaRemainingPercent / weeklyQuotaRemainingPercent.
	DailyPercent  *float64 `json:"dailyPercent,omitempty"`
	WeeklyPercent *float64 `json:"weeklyPercent,omitempty"`
	// Credit amounts: upstream value / 100. MonthlyFlexCredits comes from
	// upstream monthlyFlexCreditPurchaseAmount.
	MonthlyPromptCredits *float64 `json:"monthlyPromptCredits,omitempty"`
	UsedPromptCredits    *float64 `json:"usedPromptCredits,omitempty"`
	MonthlyFlexCredits   *float64 `json:"monthlyFlexCredits,omitempty"`
	UsedFlexCredits      *float64 `json:"usedFlexCredits,omitempty"`
}
// GetUserStatus fetches the account profile and plan/quota status.
//
// Unlike the other RPCs on Client, it always targets
// https://server.codeium.com and ignores c.BaseURL. Credit amounts are
// divided by 100 before being returned; quota percentages are passed through
// unchanged. A credit field that is absent upstream comes back as nil.
func (c *Client) GetUserStatus(ctx context.Context, token string) (*UserStatus, error) {
	jwt := StripDevinPrefix(token)
	request := map[string]any{
		"metadata": map[string]any{
			"apiKey":           token,
			"ideName":          AppName,
			"ideVersion":       AppVersion,
			"extensionName":    AppName,
			"extensionVersion": "0.2.0",
			"sessionId":        generateUUID(),
			"requestId":        randomUint64String(),
		},
	}
	const endpoint = "https://server.codeium.com/exa.api_server_pb.ApiServerService/GetUserStatus"
	raw, err := c.unaryJSONURL(ctx, endpoint, request, jwt)
	if err != nil {
		return nil, err
	}

	var envelope struct {
		UserStatus struct {
			UserID     string `json:"userId"`
			TeamID     string `json:"teamId"`
			Name       string `json:"name"`
			Email      string `json:"email"`
			PlanStatus struct {
				PlanInfo struct {
					// Upstream may send planName as a string (e.g. "Trial") or as a
					// number. Keep it as json.RawMessage and stringify it later;
					// a json.Number would reject strings and abort decoding of the
					// whole userStatus.
					PlanName             json.RawMessage `json:"planName"`
					MonthlyPromptCredits json.Number     `json:"monthlyPromptCredits"`
					MonthlyFlexCredits   json.Number     `json:"monthlyFlexCreditPurchaseAmount"`
				} `json:"planInfo"`
				DailyQuotaRemainingPercent  *float64    `json:"dailyQuotaRemainingPercent"`
				WeeklyQuotaRemainingPercent *float64    `json:"weeklyQuotaRemainingPercent"`
				UsedPromptCredits           json.Number `json:"usedPromptCredits"`
				UsedFlexCredits             json.Number `json:"usedFlexCredits"`
			} `json:"planStatus"`
		} `json:"userStatus"`
	}
	if decodeErr := json.Unmarshal(raw, &envelope); decodeErr != nil {
		return nil, fmt.Errorf("decode: %w (body=%s)", decodeErr, truncate(string(raw), 300))
	}

	// hundredths converts a legacy credit value (reported in hundredths) into
	// whole credits. Empty or non-numeric input yields nil.
	hundredths := func(num json.Number) *float64 {
		if num.String() == "" {
			return nil
		}
		f, convErr := num.Float64()
		if convErr != nil {
			return nil
		}
		f /= 100
		return &f
	}

	user := envelope.UserStatus
	plan := user.PlanStatus
	info := plan.PlanInfo
	status := &UserStatus{
		UserID:               user.UserID,
		TeamID:               user.TeamID,
		Name:                 user.Name,
		Email:                user.Email,
		PlanName:             planNameString(info.PlanName),
		DailyPercent:         plan.DailyQuotaRemainingPercent,
		WeeklyPercent:        plan.WeeklyQuotaRemainingPercent,
		MonthlyPromptCredits: hundredths(info.MonthlyPromptCredits),
		UsedPromptCredits:    hundredths(plan.UsedPromptCredits),
		MonthlyFlexCredits:   hundredths(info.MonthlyFlexCredits),
		UsedFlexCredits:      hundredths(plan.UsedFlexCredits),
	}
	return status, nil
}
// planNameString turns the upstream planName value, which may be a JSON
// string or a bare number, into a plain string.
//
// A JSON string is unquoted and unescaped; JSON null and empty input both
// yield "". Any other literal (number, bool, ...) is returned as its raw JSON
// text with surrounding double quotes trimmed.
func planNameString(raw json.RawMessage) string {
	if len(raw) == 0 {
		return ""
	}
	var decoded string
	if json.Unmarshal(raw, &decoded) != nil {
		// Not a JSON string: fall back to the literal text.
		return strings.Trim(string(raw), "\"")
	}
	return decoded
}
// ModelInfo is one entry of GetCascadeModelConfigs response.
// ListModels decodes these from the response's clientModelConfigs array.
type ModelInfo struct {
	// ModelUID identifies the model; HasModel matches on it case-insensitively.
	ModelUID string `json:"modelUid"`
	// Label is the server-provided label (presumably a display name — confirm).
	Label string `json:"label"`
	// CreditMultiplier is passed through as reported upstream.
	// NOTE(review): presumably a per-request credit cost factor — confirm.
	CreditMultiplier float64 `json:"creditMultiplier"`
	IsRecommended    bool    `json:"isRecommended"`
	IsNew            bool    `json:"isNew"`
}
// ListModels fetches the cascade model catalog via the
// GetCascadeModelConfigs RPC on c.BaseURL and returns its clientModelConfigs
// entries.
//
// The slice is nil when the server omits the list. On a decode failure, the
// error includes a truncated copy of the response body.
func (c *Client) ListModels(ctx context.Context, token string) ([]ModelInfo, error) {
	jwt := StripDevinPrefix(token)
	metadata := map[string]any{
		"apiKey":           token,
		"ideName":          AppName,
		"ideVersion":       AppVersion,
		"extensionName":    AppName,
		"extensionVersion": "0.2.0",
		"sessionId":        generateUUID(),
		"requestId":        randomUint64String(),
	}
	const rpcPath = "/exa.api_server_pb.ApiServerService/GetCascadeModelConfigs"
	payload, err := c.unaryJSON(ctx, rpcPath, map[string]any{"metadata": metadata}, jwt)
	if err != nil {
		return nil, err
	}

	var catalog struct {
		ClientModelConfigs []ModelInfo `json:"clientModelConfigs"`
	}
	if decodeErr := json.Unmarshal(payload, &catalog); decodeErr != nil {
		return nil, fmt.Errorf("decode: %w (body=%s)", decodeErr, truncate(string(payload), 300))
	}
	return catalog.ClientModelConfigs, nil
}
// HasModel reports whether any entry of models has a ModelUID equal to uid
// under Unicode case folding (strings.EqualFold). A nil or empty slice never
// matches.
func HasModel(models []ModelInfo, uid string) bool {
	for i := range models {
		if !strings.EqualFold(models[i].ModelUID, uid) {
			continue
		}
		return true
	}
	return false
}
// unaryJSON POSTs body as JSON to the RPC at path, resolved against
// c.BaseURL, and returns the raw response body. Status and error handling are
// those of unaryJSONURL.
//
// BaseURL and path are joined with exactly one slash. Without this, a BaseURL
// configured with a trailing "/" would produce ".../​/exa.api_server_pb...",
// which the server may not route.
func (c *Client) unaryJSON(ctx context.Context, path string, body any, rawJWT string) ([]byte, error) {
	endpoint := strings.TrimRight(c.BaseURL, "/") + "/" + strings.TrimLeft(path, "/")
	return c.unaryJSONURL(ctx, endpoint, body, rawJWT)
}
// unaryJSONURL POSTs body, JSON-encoded, to fullURL and returns the raw
// response body.
//
// It sets the Content-Type: application/json, Connect-Protocol-Version: 1 and
// User-Agent headers. When rawJWT is non-empty, it also sends
// "Authorization: Bearer <rawJWT>".
//
// Errors:
//   - marshal, request-construction and transport errors are returned as-is;
//   - a status >= 400 yields "HTTP <code>: <truncated body>", and whatever body
//     was read is still returned alongside the error;
//   - a failure while reading a successful response body is reported as an
//     error, instead of handing back a silently truncated payload.
func (c *Client) unaryJSONURL(ctx context.Context, fullURL string, body any, rawJWT string) ([]byte, error) {
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Connect-Protocol-Version", "1")
	req.Header.Set("User-Agent", UserAgent)
	if rawJWT != "" {
		req.Header.Set("Authorization", "Bearer "+rawJWT)
	}
	resp, err := c.HTTP.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	respBody, readErr := io.ReadAll(resp.Body)
	if resp.StatusCode >= 400 {
		// The status is the more useful signal here; the message format is
		// kept stable because callers may inspect it.
		return respBody, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 300))
	}
	if readErr != nil {
		// A connection reset or timeout mid-body leaves a partial payload;
		// don't pass it off as a successful response.
		return nil, fmt.Errorf("read response body: %w", readErr)
	}
	return respBody, nil
}
// randomUint64String returns a decimal string for a random value in
// [0, 2^63). The value is built from 8 bytes supplied by readRandom, read
// big-endian, with the top bit cleared so it also fits in an int64.
//
// Any error from readRandom is ignored. NOTE(review): on failure the result
// reflects whatever readRandom left in the zeroed buffer (0 if it wrote
// nothing) — confirm this best-effort behaviour is intended.
func randomUint64String() string {
	buf := make([]byte, 8)
	_, _ = readRandom(buf)
	var acc uint64
	for i := 0; i < len(buf); i++ {
		acc = acc<<8 | uint64(buf[i])
	}
	acc &= (1 << 63) - 1
	return fmt.Sprintf("%d", acc)
}
// truncate returns s unchanged when it is at most n bytes long. Otherwise it
// returns a prefix of s followed by "...(truncated)".
//
// The prefix is at most n bytes and never ends in the middle of a UTF-8
// sequence: it backs off to the nearest rune start, so error messages built
// from upstream bodies stay valid UTF-8. A negative n is treated as 0 instead
// of panicking.
func truncate(s string, n int) string {
	if n < 0 {
		n = 0
	}
	if len(s) <= n {
		return s
	}
	// s[n] exists because len(s) > n. Continuation bytes have the form
	// 10xxxxxx; step back until s[:n] ends on a rune boundary.
	for n > 0 && s[n]&0xC0 == 0x80 {
		n--
	}
	return s[:n] + "...(truncated)"
}