chore: remove LS pool implementation
Removing all LS (Language Server Pool) related code: - backend/cmd/lsworker/ - backend/internal/pkg/lspool/ - backend/internal/service/lspool_bootstrap_service.* - deploy/ls-bin/ - deploy/lsworker.Dockerfile - deploy/lsworker-entrypoint.sh Keeping: - Claude custom fingerprint (immutable) - Antigravity OAuth and telemetry improvements - TLS fingerprint SOCKS5 Docker DNS fix - Gemini OAuth security improvements Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
3ba3a17652
commit
a3f2d4577e
@ -7,7 +7,7 @@
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
ARG NODE_IMAGE=node:24-alpine
|
ARG NODE_IMAGE=node:24-alpine
|
||||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
ARG GOLANG_IMAGE=golang:1.25-alpine
|
||||||
ARG ALPINE_IMAGE=alpine:3.21
|
ARG ALPINE_IMAGE=alpine:3.21
|
||||||
ARG DEBIAN_IMAGE=debian:bookworm-slim
|
ARG DEBIAN_IMAGE=debian:bookworm-slim
|
||||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
FROM golang:1.25.7-alpine
|
FROM golang:1.25-alpine
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
|||||||
@ -1,49 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
server, err := lspool.NewWorkerServerFromEnv()
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("failed to initialize lsworker", "err", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
httpServer := &http.Server{
|
|
||||||
Addr: envOrDefault("LSWORKER_LISTEN_ADDR", "0.0.0.0:18081"),
|
|
||||||
Handler: server.Handler(),
|
|
||||||
ReadHeaderTimeout: 10 * 1e9,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
|
||||||
defer stop()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
_ = httpServer.Shutdown(context.Background())
|
|
||||||
}()
|
|
||||||
|
|
||||||
slog.Info("lsworker listening", "addr", httpServer.Addr)
|
|
||||||
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
slog.Error("lsworker exited with error", "err", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func envOrDefault(key, fallback string) string {
|
|
||||||
if value := os.Getenv(key); value != "" {
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
@ -53,9 +53,8 @@ const (
|
|||||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.107.0
|
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.107.0
|
||||||
var defaultUserAgentVersion = "1.107.0"
|
var defaultUserAgentVersion = "1.107.0"
|
||||||
|
|
||||||
|
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 覆盖
|
||||||
// defaultClientSecret 必须通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
var defaultClientSecret = "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"
|
||||||
var defaultClientSecret string
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// 从环境变量读取版本号,未设置则使用默认值
|
// 从环境变量读取版本号,未设置则使用默认值
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
// Package claude provides constants and helpers for Claude API integration.
|
// Package claude provides constants and helpers for Claude API integration.
|
||||||
package claude
|
package claude
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
// Claude Code 客户端相关常量
|
// Claude Code 客户端相关常量
|
||||||
|
|
||||||
// DefaultCLIVersion 是当前模拟的 Claude CLI 版本
|
// DefaultCLIVersion 是当前模拟的 Claude CLI 版本
|
||||||
@ -30,32 +32,64 @@ const (
|
|||||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||||
var DroppedBetas = []string{}
|
var DroppedBetas = []string{}
|
||||||
|
|
||||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header(OAuth 账号,不含 context-1m)
|
||||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
// 使用 GetOAuthBetaHeader(modelID) 获取含 context-1m 的 model-aware 版本。
|
||||||
|
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||||
|
|
||||||
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
|
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header(OAuth,不含 context-1m)
|
||||||
//
|
//
|
||||||
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
|
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
|
||||||
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
|
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
|
||||||
// even if the request doesn't use tools, otherwise upstream may reject the
|
// even if the request doesn't use tools, otherwise upstream may reject the
|
||||||
// request as a non-Claude-Code API request.
|
// request as a non-Claude-Code API request.
|
||||||
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||||
|
|
||||||
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
|
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header(OAuth,不含 context-1m)
|
||||||
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||||
|
|
||||||
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
||||||
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting + "," + BetaContextManagement
|
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting + "," + BetaContextManagement
|
||||||
|
|
||||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(OAuth,不含 claude-code / context-1m)
|
||||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking + "," + BetaEffort
|
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking + "," + BetaEffort
|
||||||
|
|
||||||
// APIKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
// APIKeyBetaHeader API-key 账号使用的 anthropic-beta header(不含 oauth / context-1m)
|
||||||
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaEffort + "," + BetaPromptCachingScope
|
// 使用 GetAPIKeyBetaHeader(modelID) 获取含 context-1m 的 model-aware 版本。
|
||||||
|
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaEffort + "," + BetaPromptCachingScope
|
||||||
|
|
||||||
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不含 oauth / claude-code)
|
||||||
const APIKeyHaikuBetaHeader = BetaInterleavedThinking + "," + BetaEffort
|
const APIKeyHaikuBetaHeader = BetaInterleavedThinking + "," + BetaEffort
|
||||||
|
|
||||||
|
// ModelSupports1M 判断模型是否支持 1M context window。
|
||||||
|
// 与 claude-code-2.1.88 bundle 中 modelSupports1M 逻辑保持一致:
|
||||||
|
//
|
||||||
|
// claude-sonnet-4 系列 和 claude-opus-4-6 支持 1M context。
|
||||||
|
func ModelSupports1M(modelID string) bool {
|
||||||
|
lower := strings.ToLower(strings.TrimSpace(modelID))
|
||||||
|
return strings.Contains(lower, "claude-sonnet-4") || strings.Contains(lower, "opus-4-6")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOAuthBetaHeader 返回 OAuth 账号的 beta header。
|
||||||
|
// 仅当模型支持 1M context 时才包含 context-1m-2025-08-07。
|
||||||
|
func GetOAuthBetaHeader(modelID string) string {
|
||||||
|
if ModelSupports1M(modelID) {
|
||||||
|
return BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||||
|
}
|
||||||
|
return DefaultBetaHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAPIKeyBetaHeader 返回 API-key 账号的 beta header。
|
||||||
|
// 仅当模型支持 1M context 时才包含 context-1m-2025-08-07。
|
||||||
|
func GetAPIKeyBetaHeader(modelID string) string {
|
||||||
|
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||||
|
return APIKeyHaikuBetaHeader
|
||||||
|
}
|
||||||
|
if ModelSupports1M(modelID) {
|
||||||
|
return BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaEffort + "," + BetaPromptCachingScope
|
||||||
|
}
|
||||||
|
return APIKeyBetaHeader
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
||||||
var DefaultHeaders = map[string]string{
|
var DefaultHeaders = map[string]string{
|
||||||
// Keep these in sync with recent Claude CLI traffic to reduce the chance
|
// Keep these in sync with recent Claude CLI traffic to reduce the chance
|
||||||
@ -70,7 +104,7 @@ var DefaultHeaders = map[string]string{
|
|||||||
"X-Stainless-Retry-Count": "0",
|
"X-Stainless-Retry-Count": "0",
|
||||||
"X-Stainless-Timeout": "600",
|
"X-Stainless-Timeout": "600",
|
||||||
"X-App": "cli",
|
"X-App": "cli",
|
||||||
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
"anthropic-version": "2023-06-01",
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyFingerprintOverrides 用配置覆盖默认指纹值(每个实例可设不同值)
|
// ApplyFingerprintOverrides 用配置覆盖默认指纹值(每个实例可设不同值)
|
||||||
|
|||||||
@ -1,13 +0,0 @@
|
|||||||
package lspool
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
// Backend is the control-plane abstraction used by the HTTP upstream wrapper.
|
|
||||||
// It may be backed by a local in-process Pool or by remote LS workers.
|
|
||||||
type Backend interface {
|
|
||||||
GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error)
|
|
||||||
SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time)
|
|
||||||
SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32)
|
|
||||||
Stats() map[string]any
|
|
||||||
Close()
|
|
||||||
}
|
|
||||||
@ -1,94 +0,0 @@
|
|||||||
// Package lspool provides LS-mode integration for the antigravity gateway.
|
|
||||||
//
|
|
||||||
// When LS mode is enabled (via ANTIGRAVITY_LS_MODE=true), requests to
|
|
||||||
// streamGenerateContent are routed through a real Language Server instance
|
|
||||||
// instead of directly to cloudcode-pa. This provides:
|
|
||||||
//
|
|
||||||
// - Authentic TLS fingerprint (Google's own Go binary)
|
|
||||||
// - Real session management and Heartbeat
|
|
||||||
// - Indistinguishable from a real IDE instance
|
|
||||||
//
|
|
||||||
// To enable: set environment variable ANTIGRAVITY_LS_MODE=true
|
|
||||||
// To configure: set ANTIGRAVITY_APP_ROOT to the AntiGravity.app path
|
|
||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
globalBackend Backend
|
|
||||||
globalPoolOnce sync.Once
|
|
||||||
lsModeEnabled bool
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
lsModeEnabled = os.Getenv("ANTIGRAVITY_LS_MODE") == "true"
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsLSModeEnabled returns whether LS mode is active
|
|
||||||
func IsLSModeEnabled() bool {
|
|
||||||
return lsModeEnabled
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
LSStrategyDirect = "direct"
|
|
||||||
LSStrategyJSParity = "js-parity"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CurrentLSStrategy returns the active LS routing strategy.
|
|
||||||
// Unknown values are treated as "direct" for safety.
|
|
||||||
func CurrentLSStrategy() string {
|
|
||||||
switch strings.ToLower(strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_STRATEGY"))) {
|
|
||||||
case "", LSStrategyDirect:
|
|
||||||
return LSStrategyDirect
|
|
||||||
case LSStrategyJSParity:
|
|
||||||
return LSStrategyJSParity
|
|
||||||
default:
|
|
||||||
return LSStrategyDirect
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GlobalPool returns the singleton LS pool instance
|
|
||||||
// Creates it on first call if LS mode is enabled
|
|
||||||
func GlobalPool(cfg *config.Config) Backend {
|
|
||||||
if !lsModeEnabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
globalPoolOnce.Do(func() {
|
|
||||||
manager, err := NewWorkerManagerFromConfig(cfg)
|
|
||||||
if err != nil {
|
|
||||||
slog.Default().Error("failed to initialize LS worker manager", "err", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
globalBackend = manager
|
|
||||||
})
|
|
||||||
return globalBackend
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown closes the global pool
|
|
||||||
func Shutdown() {
|
|
||||||
if globalBackend != nil {
|
|
||||||
globalBackend.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// StatusInfo returns the current LS pool status for diagnostics
|
|
||||||
func StatusInfo() map[string]any {
|
|
||||||
info := map[string]any{
|
|
||||||
"ls_mode_enabled": lsModeEnabled,
|
|
||||||
"build": "enhanced",
|
|
||||||
"user_agent": "antigravity/1.107.0",
|
|
||||||
}
|
|
||||||
if lsModeEnabled && globalBackend != nil {
|
|
||||||
stats := globalBackend.Stats()
|
|
||||||
info["pool_total"] = stats["total"]
|
|
||||||
info["pool_active"] = stats["active"]
|
|
||||||
}
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
@ -1,864 +0,0 @@
|
|||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func readConnectFrame(r io.Reader) ([]byte, error) {
|
|
||||||
header := make([]byte, 5)
|
|
||||||
if _, err := io.ReadFull(r, header); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
payloadLen := binary.BigEndian.Uint32(header[1:5])
|
|
||||||
payload := make([]byte, payloadLen)
|
|
||||||
if _, err := io.ReadFull(r, payload); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return payload, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeProtoBytesField(data []byte, targetField int) []byte {
|
|
||||||
i := 0
|
|
||||||
for i < len(data) {
|
|
||||||
tag, n := binary.Uvarint(data[i:])
|
|
||||||
if n <= 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
i += n
|
|
||||||
fieldNum := int(tag >> 3)
|
|
||||||
wireType := tag & 0x7
|
|
||||||
switch wireType {
|
|
||||||
case 0:
|
|
||||||
_, n = binary.Uvarint(data[i:])
|
|
||||||
if n <= 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
i += n
|
|
||||||
case 2:
|
|
||||||
length, n := binary.Uvarint(data[i:])
|
|
||||||
if n <= 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
i += n
|
|
||||||
if i+int(length) > len(data) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if fieldNum == targetField {
|
|
||||||
return data[i : i+int(length)]
|
|
||||||
}
|
|
||||||
i += int(length)
|
|
||||||
case 1:
|
|
||||||
i += 8
|
|
||||||
case 5:
|
|
||||||
i += 4
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeProtoBytesFields(data []byte, targetField int) [][]byte {
|
|
||||||
var values [][]byte
|
|
||||||
i := 0
|
|
||||||
for i < len(data) {
|
|
||||||
tag, n := binary.Uvarint(data[i:])
|
|
||||||
if n <= 0 {
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
i += n
|
|
||||||
fieldNum := int(tag >> 3)
|
|
||||||
wireType := tag & 0x7
|
|
||||||
switch wireType {
|
|
||||||
case 0:
|
|
||||||
_, n = binary.Uvarint(data[i:])
|
|
||||||
if n <= 0 {
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
i += n
|
|
||||||
case 2:
|
|
||||||
length, n := binary.Uvarint(data[i:])
|
|
||||||
if n <= 0 {
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
i += n
|
|
||||||
if i+int(length) > len(data) {
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
if fieldNum == targetField {
|
|
||||||
values = append(values, append([]byte(nil), data[i:i+int(length)]...))
|
|
||||||
}
|
|
||||||
i += int(length)
|
|
||||||
case 1:
|
|
||||||
i += 8
|
|
||||||
case 5:
|
|
||||||
i += 4
|
|
||||||
default:
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeTopicRows(topic []byte) map[string]string {
|
|
||||||
rows := make(map[string]string)
|
|
||||||
for _, entry := range decodeProtoBytesFields(topic, 1) {
|
|
||||||
key := decodeProtoString(entry, 1)
|
|
||||||
row := decodeProtoBytesField(entry, 2)
|
|
||||||
rows[key] = decodeProtoString(row, 1)
|
|
||||||
}
|
|
||||||
return rows
|
|
||||||
}
|
|
||||||
|
|
||||||
func requireBase64PrimitiveValue(t *testing.T, got string, want []byte) {
|
|
||||||
t.Helper()
|
|
||||||
decoded, err := base64.StdEncoding.DecodeString(got)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, want, decoded)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestMockExtensionServerTokenInjection verifies the token injection flow:
|
|
||||||
// Extension → MockExtensionServer → LS subscribes uss-oauth → gets OAuthTokenInfo
|
|
||||||
func TestMockExtensionServerTokenInjection(t *testing.T) {
|
|
||||||
csrf := "test-csrf-token"
|
|
||||||
srv, err := NewMockExtensionServer(csrf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
// 1. Set token for an account
|
|
||||||
srv.SetToken("account-1", &TokenInfo{
|
|
||||||
AccessToken: "ya29.test-access-token",
|
|
||||||
RefreshToken: "1//test-refresh-token",
|
|
||||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
|
||||||
})
|
|
||||||
|
|
||||||
// 2. Verify token is stored
|
|
||||||
srv.mu.RLock()
|
|
||||||
info, ok := srv.tokens["account-1"]
|
|
||||||
srv.mu.RUnlock()
|
|
||||||
require.True(t, ok)
|
|
||||||
require.Equal(t, "ya29.test-access-token", info.AccessToken)
|
|
||||||
require.Equal(t, "1//test-refresh-token", info.RefreshToken)
|
|
||||||
require.False(t, info.ExpiresAt.IsZero())
|
|
||||||
|
|
||||||
// 3. Simulate LS subscribing to uss-oauth (HTTP request to mock server)
|
|
||||||
req, _ := http.NewRequest("POST",
|
|
||||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
|
|
||||||
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-oauth"))))
|
|
||||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
|
||||||
req.Header.Set("Content-Type", "application/connect+proto")
|
|
||||||
|
|
||||||
// The stream handler will block, so run in background and cancel after we confirm connection
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err == nil {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
require.Equal(t, 200, resp.StatusCode)
|
|
||||||
require.Equal(t, "application/connect+proto", resp.Header.Get("Content-Type"))
|
|
||||||
|
|
||||||
// Read the first envelope frame (initial state)
|
|
||||||
header := make([]byte, 5)
|
|
||||||
n, readErr := resp.Body.Read(header)
|
|
||||||
if readErr == nil && n == 5 {
|
|
||||||
require.Equal(t, byte(0x00), header[0], "first byte should be 0x00 (data frame)")
|
|
||||||
t.Logf("Received initial state frame: flags=%d, payload_len=%d", header[0], header[1:5])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestMockExtensionServerCSRF verifies CSRF token validation
|
|
||||||
func TestMockExtensionServerCSRF(t *testing.T) {
|
|
||||||
csrf := "correct-csrf"
|
|
||||||
srv, err := NewMockExtensionServer(csrf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
base := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/Heartbeat", srv.Port())
|
|
||||||
|
|
||||||
// Wrong CSRF → 403
|
|
||||||
req, _ := http.NewRequest("POST", base, nil)
|
|
||||||
req.Header.Set("x-codeium-csrf-token", "wrong-csrf")
|
|
||||||
req.Header.Set("Content-Type", "application/proto")
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer resp.Body.Close()
|
|
||||||
require.Equal(t, 403, resp.StatusCode)
|
|
||||||
|
|
||||||
// Correct CSRF → 200
|
|
||||||
req2, _ := http.NewRequest("POST", base, nil)
|
|
||||||
req2.Header.Set("x-codeium-csrf-token", csrf)
|
|
||||||
req2.Header.Set("Content-Type", "application/proto")
|
|
||||||
resp2, err := http.DefaultClient.Do(req2)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer resp2.Body.Close()
|
|
||||||
require.Equal(t, 200, resp2.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestMockExtensionServerGetSecretValue verifies the fallback token path
|
|
||||||
func TestMockExtensionServerGetSecretValue(t *testing.T) {
|
|
||||||
csrf := "test-csrf"
|
|
||||||
srv, err := NewMockExtensionServer(csrf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
srv.SetToken("acc", &TokenInfo{AccessToken: "ya29.secret-token"})
|
|
||||||
|
|
||||||
// GetSecretValue should return the token
|
|
||||||
req, _ := http.NewRequest("POST",
|
|
||||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/GetSecretValue", srv.Port()),
|
|
||||||
nil)
|
|
||||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
|
||||||
req.Header.Set("Content-Type", "application/proto")
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer resp.Body.Close()
|
|
||||||
require.Equal(t, 200, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestOAuthTokenInfoProto verifies the proto encoding matches real IDE format
|
|
||||||
func TestOAuthTokenInfoProto(t *testing.T) {
|
|
||||||
expiry := time.Date(2026, 3, 29, 19, 0, 0, 0, time.UTC)
|
|
||||||
bin := buildOAuthTokenInfoBinary("ya29.test", "1//refresh", expiry)
|
|
||||||
|
|
||||||
// Verify fields are present by checking proto wire format
|
|
||||||
require.True(t, len(bin) > 0, "proto should not be empty")
|
|
||||||
|
|
||||||
// Field 1 (access_token): tag=0x0a, value="ya29.test"
|
|
||||||
require.Contains(t, string(bin), "ya29.test")
|
|
||||||
// Field 2 (token_type): tag=0x12, value="Bearer"
|
|
||||||
require.Contains(t, string(bin), "Bearer")
|
|
||||||
// Field 3 (refresh_token): tag=0x1a, value="1//refresh"
|
|
||||||
require.Contains(t, string(bin), "1//refresh")
|
|
||||||
|
|
||||||
// Without refresh_token
|
|
||||||
binNoRefresh := buildOAuthTokenInfoBinary("ya29.test", "", expiry)
|
|
||||||
require.NotContains(t, string(binNoRefresh), "1//refresh")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestOAuthTokenInfoWithRealExpiry verifies expiry uses real time, not hardcoded
|
|
||||||
func TestOAuthTokenInfoWithRealExpiry(t *testing.T) {
|
|
||||||
future := time.Now().Add(2 * time.Hour)
|
|
||||||
bin := buildOAuthTokenInfoBinary("token", "refresh", future)
|
|
||||||
|
|
||||||
// Zero expiry should default to ~1h
|
|
||||||
binZero := buildOAuthTokenInfoBinary("token", "refresh", time.Time{})
|
|
||||||
|
|
||||||
// They should be different lengths or content (different expiry timestamps)
|
|
||||||
// Both should be valid (non-empty)
|
|
||||||
require.True(t, len(bin) > 0)
|
|
||||||
require.True(t, len(binZero) > 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestUSSTopicWithOAuth verifies the full USS topic proto structure
|
|
||||||
func TestUSSTopicWithOAuth(t *testing.T) {
|
|
||||||
expiry := time.Now().Add(1 * time.Hour)
|
|
||||||
topic := buildUSSTopicWithOAuth("ya29.access", "1//refresh", expiry)
|
|
||||||
|
|
||||||
require.True(t, len(topic) > 0)
|
|
||||||
// The topic should contain the sentinel key
|
|
||||||
require.Contains(t, string(topic), "oauthTokenInfoSentinelKey")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUSSTopicWithModelCredits(t *testing.T) {
|
|
||||||
available := int32(123)
|
|
||||||
minimum := int32(50)
|
|
||||||
topic := buildUSSTopicWithModelCredits(&ModelCreditsInfo{
|
|
||||||
UseAICredits: true,
|
|
||||||
AvailableCredits: &available,
|
|
||||||
MinimumCreditAmountForUsage: &minimum,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.True(t, len(topic) > 0)
|
|
||||||
require.Contains(t, string(topic), useAICreditsSentinelKey)
|
|
||||||
require.Contains(t, string(topic), availableCreditsSentinelKey)
|
|
||||||
require.Contains(t, string(topic), minimumCreditAmountForUsageKey)
|
|
||||||
|
|
||||||
rows := decodeTopicRows(topic)
|
|
||||||
requireBase64PrimitiveValue(t, rows[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
|
|
||||||
requireBase64PrimitiveValue(t, rows[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
|
|
||||||
requireBase64PrimitiveValue(t, rows[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMockExtensionServerModelCreditsDynamicUpdate(t *testing.T) {
|
|
||||||
csrf := "test-csrf-token"
|
|
||||||
srv, err := NewMockExtensionServer(csrf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
srv.SetModelCredits("account-1", &ModelCreditsInfo{})
|
|
||||||
|
|
||||||
req, _ := http.NewRequest("POST",
|
|
||||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
|
|
||||||
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-modelCredits"))))
|
|
||||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
|
||||||
req.Header.Set("Content-Type", "application/connect+proto")
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer resp.Body.Close()
|
|
||||||
require.Equal(t, 200, resp.StatusCode)
|
|
||||||
|
|
||||||
// Drain the initial_state frame first.
|
|
||||||
_, err = readConnectFrame(resp.Body)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
available := int32(77)
|
|
||||||
minimum := int32(25)
|
|
||||||
srv.SetModelCredits("account-1", &ModelCreditsInfo{
|
|
||||||
UseAICredits: true,
|
|
||||||
AvailableCredits: &available,
|
|
||||||
MinimumCreditAmountForUsage: &minimum,
|
|
||||||
})
|
|
||||||
|
|
||||||
values := make(map[string]string, 3)
|
|
||||||
for len(values) < 3 {
|
|
||||||
frame, readErr := readConnectFrame(resp.Body)
|
|
||||||
require.NoError(t, readErr)
|
|
||||||
applied := decodeProtoBytesField(frame, 2)
|
|
||||||
require.NotEmpty(t, applied)
|
|
||||||
key := decodeProtoString(applied, 1)
|
|
||||||
row := decodeProtoBytesField(applied, 2)
|
|
||||||
values[key] = decodeProtoString(row, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Contains(t, values, useAICreditsSentinelKey)
|
|
||||||
require.Contains(t, values, availableCreditsSentinelKey)
|
|
||||||
require.Contains(t, values, minimumCreditAmountForUsageKey)
|
|
||||||
requireBase64PrimitiveValue(t, values[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
|
|
||||||
requireBase64PrimitiveValue(t, values[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
|
|
||||||
requireBase64PrimitiveValue(t, values[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBuildInitialStateUpdate verifies the USS update wrapper
|
|
||||||
func TestBuildInitialStateUpdate(t *testing.T) {
|
|
||||||
topicData := buildEmptyTopic()
|
|
||||||
update := buildInitialStateUpdate(topicData)
|
|
||||||
// Should be a valid proto bytes field (field 1 = initial_state)
|
|
||||||
require.True(t, len(update) >= 0) // empty topic is valid
|
|
||||||
|
|
||||||
topicData2 := buildUSSTopicWithOAuth("token", "refresh", time.Now().Add(1*time.Hour))
|
|
||||||
update2 := buildInitialStateUpdate(topicData2)
|
|
||||||
require.True(t, len(update2) > len(update), "non-empty topic should produce larger update")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPoolSetAccountTokenComplete verifies pool accepts full credential set
|
|
||||||
func TestPoolSetAccountTokenComplete(t *testing.T) {
|
|
||||||
csrf := "pool-csrf"
|
|
||||||
srv, err := NewMockExtensionServer(csrf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
pool := &Pool{
|
|
||||||
config: DefaultConfig(),
|
|
||||||
instances: make(map[string][]*Instance),
|
|
||||||
extServer: srv,
|
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
}
|
|
||||||
|
|
||||||
expiry := time.Now().Add(1 * time.Hour)
|
|
||||||
pool.SetAccountToken("acc-1", "ya29.full-token", "1//full-refresh", expiry)
|
|
||||||
|
|
||||||
srv.mu.RLock()
|
|
||||||
info := srv.tokens["acc-1"]
|
|
||||||
srv.mu.RUnlock()
|
|
||||||
|
|
||||||
require.NotNil(t, info)
|
|
||||||
require.Equal(t, "ya29.full-token", info.AccessToken)
|
|
||||||
require.Equal(t, "1//full-refresh", info.RefreshToken)
|
|
||||||
require.False(t, info.ExpiresAt.IsZero())
|
|
||||||
require.WithinDuration(t, expiry, info.ExpiresAt, time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPoolSetAccountModelCreditsComplete(t *testing.T) {
|
|
||||||
csrf := "pool-csrf"
|
|
||||||
srv, err := NewMockExtensionServer(csrf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
pool := &Pool{
|
|
||||||
config: DefaultConfig(),
|
|
||||||
instances: make(map[string][]*Instance),
|
|
||||||
extServer: srv,
|
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
}
|
|
||||||
|
|
||||||
available := int32(77)
|
|
||||||
minimum := int32(25)
|
|
||||||
pool.SetAccountModelCredits("acc-1", true, &available, &minimum)
|
|
||||||
|
|
||||||
srv.mu.RLock()
|
|
||||||
info := srv.credits["acc-1"]
|
|
||||||
srv.mu.RUnlock()
|
|
||||||
|
|
||||||
require.NotNil(t, info)
|
|
||||||
require.True(t, info.UseAICredits)
|
|
||||||
require.NotNil(t, info.AvailableCredits)
|
|
||||||
require.Equal(t, available, *info.AvailableCredits)
|
|
||||||
require.NotNil(t, info.MinimumCreditAmountForUsage)
|
|
||||||
require.Equal(t, minimum, *info.MinimumCreditAmountForUsage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestUpstreamAdapterExtractsCredentials verifies internal LS headers are extracted and stripped.
|
|
||||||
func TestUpstreamAdapterExtractsCredentials(t *testing.T) {
|
|
||||||
// Create a mock upstream that records what it receives
|
|
||||||
var receivedHeaders http.Header
|
|
||||||
var mu sync.Mutex
|
|
||||||
fallback := &recordingUpstreamWithCallback{}
|
|
||||||
fallback.onDo = func(req *http.Request) {
|
|
||||||
mu.Lock()
|
|
||||||
receivedHeaders = req.Header.Clone()
|
|
||||||
mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
csrf := "test-csrf"
|
|
||||||
srv, err := NewMockExtensionServer(csrf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
pool := &Pool{
|
|
||||||
config: DefaultConfig(),
|
|
||||||
instances: make(map[string][]*Instance),
|
|
||||||
extServer: srv,
|
|
||||||
}
|
|
||||||
|
|
||||||
upstream := NewLSPoolUpstream(pool, fallback)
|
|
||||||
|
|
||||||
// Non-streamGenerateContent request → should pass through to fallback
|
|
||||||
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", nil)
|
|
||||||
req.Header.Set("Authorization", "Bearer ya29.test")
|
|
||||||
req.Header.Set("X-Antigravity-Refresh-Token", "1//secret-refresh")
|
|
||||||
req.Header.Set("X-Antigravity-Token-Expiry", "2026-03-29T19:00:00Z")
|
|
||||||
req.Header.Set(useAICreditsHeader, "true")
|
|
||||||
req.Header.Set(availableCreditsHeader, "42")
|
|
||||||
req.Header.Set(minimumCreditAmountHeader, "50")
|
|
||||||
|
|
||||||
resp, err := upstream.Do(req, "", 1, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
|
|
||||||
// Internal headers should never leak to the direct upstream.
|
|
||||||
mu.Lock()
|
|
||||||
require.Empty(t, receivedHeaders.Get("X-Antigravity-Refresh-Token"))
|
|
||||||
require.Empty(t, receivedHeaders.Get("X-Antigravity-Token-Expiry"))
|
|
||||||
require.Empty(t, receivedHeaders.Get(useAICreditsHeader))
|
|
||||||
require.Empty(t, receivedHeaders.Get(availableCreditsHeader))
|
|
||||||
require.Empty(t, receivedHeaders.Get(minimumCreditAmountHeader))
|
|
||||||
mu.Unlock()
|
|
||||||
|
|
||||||
srv.mu.RLock()
|
|
||||||
tokenInfo := srv.tokens["1"]
|
|
||||||
creditsInfo := srv.credits["1"]
|
|
||||||
srv.mu.RUnlock()
|
|
||||||
|
|
||||||
require.NotNil(t, tokenInfo)
|
|
||||||
require.Equal(t, "ya29.test", tokenInfo.AccessToken)
|
|
||||||
require.NotNil(t, creditsInfo)
|
|
||||||
require.True(t, creditsInfo.UseAICredits)
|
|
||||||
require.NotNil(t, creditsInfo.AvailableCredits)
|
|
||||||
require.Equal(t, int32(42), *creditsInfo.AvailableCredits)
|
|
||||||
require.NotNil(t, creditsInfo.MinimumCreditAmountForUsage)
|
|
||||||
require.Equal(t, int32(50), *creditsInfo.MinimumCreditAmountForUsage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestExtractPromptAndModelMultiTurn verifies multi-turn prompt extraction
|
|
||||||
func TestExtractPromptAndModelMultiTurn(t *testing.T) {
|
|
||||||
body := `{
|
|
||||||
"model": "claude-sonnet-4-6",
|
|
||||||
"request": {
|
|
||||||
"systemInstruction": {"parts": [{"text": "You are helpful"}]},
|
|
||||||
"contents": [
|
|
||||||
{"role": "user", "parts": [{"text": "Hello"}]},
|
|
||||||
{"role": "model", "parts": [{"text": "Hi there!"}]},
|
|
||||||
{"role": "user", "parts": [{"text": "How are you?"}]}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
prompt, model := extractPromptAndModel([]byte(body))
|
|
||||||
require.Equal(t, "claude-sonnet-4-6", model)
|
|
||||||
require.Contains(t, prompt, "You are helpful")
|
|
||||||
require.Contains(t, prompt, "Hello")
|
|
||||||
require.Contains(t, prompt, "Hi there!")
|
|
||||||
require.Contains(t, prompt, "How are you?")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestExtractUsageFromTrajectory verifies token usage extraction
|
|
||||||
func TestExtractUsageFromTrajectory(t *testing.T) {
|
|
||||||
resp := `{
|
|
||||||
"trajectory": {
|
|
||||||
"steps": [{
|
|
||||||
"type": "CORTEX_STEP_TYPE_PLANNER_RESPONSE",
|
|
||||||
"status": "CORTEX_STEP_STATUS_DONE",
|
|
||||||
"plannerResponse": {"response": "OK"},
|
|
||||||
"metadata": {
|
|
||||||
"modelUsage": {
|
|
||||||
"inputTokens": "150",
|
|
||||||
"outputTokens": "5"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
usage := extractUsageFromTrajectory([]byte(resp))
|
|
||||||
require.NotNil(t, usage)
|
|
||||||
require.Equal(t, 150, usage["promptTokenCount"])
|
|
||||||
require.Equal(t, 5, usage["candidatesTokenCount"])
|
|
||||||
require.Equal(t, 155, usage["totalTokenCount"])
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSSEChunkFormat verifies the Gemini SSE output format
|
|
||||||
func TestSSEChunkFormat(t *testing.T) {
|
|
||||||
chunk := buildGeminiSSEChunk("Hello world")
|
|
||||||
require.True(t, len(chunk) > 0)
|
|
||||||
require.Contains(t, chunk, "data: ")
|
|
||||||
require.Contains(t, chunk, `"text":"Hello world"`)
|
|
||||||
require.Contains(t, chunk, `"role":"model"`)
|
|
||||||
require.True(t, chunk[len(chunk)-2:] == "\n\n")
|
|
||||||
|
|
||||||
// Verify it's valid JSON after stripping "data: " prefix
|
|
||||||
jsonStr := chunk[len("data: ") : len(chunk)-2]
|
|
||||||
var parsed map[string]any
|
|
||||||
err := json.Unmarshal([]byte(jsonStr), &parsed)
|
|
||||||
require.NoError(t, err)
|
|
||||||
response := parsed["response"].(map[string]any)
|
|
||||||
candidates := response["candidates"].([]any)
|
|
||||||
require.Len(t, candidates, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSSEFinalChunkFormat verifies the final SSE chunk with usage
|
|
||||||
func TestSSEFinalChunkFormat(t *testing.T) {
|
|
||||||
usage := map[string]any{
|
|
||||||
"promptTokenCount": 100,
|
|
||||||
"candidatesTokenCount": 50,
|
|
||||||
"totalTokenCount": 150,
|
|
||||||
}
|
|
||||||
chunk := buildGeminiSSEFinalChunk(usage)
|
|
||||||
require.Contains(t, chunk, "data: ")
|
|
||||||
require.Contains(t, chunk, `"finishReason":"STOP"`)
|
|
||||||
require.Contains(t, chunk, `"usageMetadata"`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStreamCascadeResponsePollsImmediately(t *testing.T) {
|
|
||||||
var getCalls atomic.Int32
|
|
||||||
|
|
||||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
|
||||||
|
|
||||||
if strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory") {
|
|
||||||
getCalls.Add(1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
http.NotFound(w, r)
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
inst := &Instance{
|
|
||||||
AccountID: "42",
|
|
||||||
CSRF: "test-csrf",
|
|
||||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
|
||||||
client: server.Client(),
|
|
||||||
healthy: true,
|
|
||||||
lastUsed: time.Now(),
|
|
||||||
}
|
|
||||||
upstream := NewLSPoolUpstream(&Pool{}, &recordingUpstream{})
|
|
||||||
|
|
||||||
pr, pw := io.Pipe()
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
upstream.streamCascadeResponse(ctx, inst, "cid-1", pw, nil, nil)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(pr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
<-done
|
|
||||||
|
|
||||||
require.GreaterOrEqual(t, getCalls.Load(), int32(1))
|
|
||||||
require.Contains(t, string(body), "hello from ls")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRequestHasToolsEdgeCases verifies tool detection edge cases
|
|
||||||
func TestRequestHasToolsEdgeCases(t *testing.T) {
|
|
||||||
// null tools
|
|
||||||
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":null}`)))
|
|
||||||
// tools with empty function declarations
|
|
||||||
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[]}]}`)))
|
|
||||||
// deeply nested wrapped format
|
|
||||||
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"codeExecution":{}}]}}`)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestJSParityRouteReusesCascadeSession(t *testing.T) {
|
|
||||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
|
||||||
|
|
||||||
var startCalls atomic.Int32
|
|
||||||
var sendCalls atomic.Int32
|
|
||||||
var getCalls atomic.Int32
|
|
||||||
var sendBodiesMu sync.Mutex
|
|
||||||
var sendBodies []map[string]any
|
|
||||||
|
|
||||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
|
||||||
startCalls.Add(1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
|
||||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
|
||||||
sendCalls.Add(1)
|
|
||||||
var payload map[string]any
|
|
||||||
err := json.NewDecoder(r.Body).Decode(&payload)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sendBodiesMu.Lock()
|
|
||||||
sendBodies = append(sendBodies, payload)
|
|
||||||
sendBodiesMu.Unlock()
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
|
||||||
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
|
|
||||||
call := getCalls.Add(1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
text := "hello from ls"
|
|
||||||
if call > 1 {
|
|
||||||
text = "follow up from ls"
|
|
||||||
}
|
|
||||||
_, _ = w.Write([]byte(fmt.Sprintf(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"%s"}}]}}`, text)))
|
|
||||||
default:
|
|
||||||
http.NotFound(w, r)
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
inst := &Instance{
|
|
||||||
AccountID: "42",
|
|
||||||
CSRF: "test-csrf",
|
|
||||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
|
||||||
client: server.Client(),
|
|
||||||
healthy: true,
|
|
||||||
lastUsed: time.Now(),
|
|
||||||
}
|
|
||||||
inst.SetModelMappingReady(true)
|
|
||||||
pool := &Pool{
|
|
||||||
config: Config{ReplicasPerAccount: 1},
|
|
||||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
|
||||||
}
|
|
||||||
upstream := NewLSPoolUpstream(pool, &recordingUpstream{})
|
|
||||||
|
|
||||||
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
|
||||||
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
|
|
||||||
require.NoError(t, err)
|
|
||||||
req1.Header.Set("Authorization", "Bearer downstream-a")
|
|
||||||
|
|
||||||
resp1, err := upstream.Do(req1, "", 42, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
body1, err := io.ReadAll(resp1.Body)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Contains(t, string(body1), `"text":"hello from ls"`)
|
|
||||||
|
|
||||||
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
|
|
||||||
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
|
|
||||||
require.NoError(t, err)
|
|
||||||
req2.Header.Set("Authorization", "Bearer downstream-a")
|
|
||||||
|
|
||||||
resp2, err := upstream.Do(req2, "", 42, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
body2, err := io.ReadAll(resp2.Body)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Contains(t, string(body2), `"text":"follow up from ls"`)
|
|
||||||
|
|
||||||
require.Equal(t, int32(1), startCalls.Load(), "cascade should be reused for append-only transcript")
|
|
||||||
require.Equal(t, int32(2), sendCalls.Load())
|
|
||||||
|
|
||||||
sendBodiesMu.Lock()
|
|
||||||
require.Len(t, sendBodies, 2)
|
|
||||||
firstSend := sendBodies[0]
|
|
||||||
sendBodiesMu.Unlock()
|
|
||||||
|
|
||||||
require.Equal(t, "cid-1", firstSend["cascadeId"])
|
|
||||||
require.Equal(t, false, firstSend["blocking"])
|
|
||||||
metadata, ok := firstSend["metadata"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
require.Equal(t, "antigravity", metadata["ideName"])
|
|
||||||
require.Equal(t, "1.107.0", metadata["ideVersion"])
|
|
||||||
require.NotContains(t, firstSend, "clientType")
|
|
||||||
require.NotContains(t, firstSend, "messageOrigin")
|
|
||||||
cascadeConfig, ok := firstSend["cascadeConfig"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
plannerConfig, ok := cascadeConfig["plannerConfig"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
require.NotEmpty(t, requestedModel["model"])
|
|
||||||
require.Len(t, plannerConfig, 1)
|
|
||||||
require.Len(t, cascadeConfig, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestJSParityRouteFallsBackOnSystemInstructionDrift(t *testing.T) {
|
|
||||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
|
||||||
|
|
||||||
var startCalls atomic.Int32
|
|
||||||
var sendCalls atomic.Int32
|
|
||||||
|
|
||||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
|
||||||
startCalls.Add(1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
|
||||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
|
||||||
sendCalls.Add(1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
|
||||||
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
|
|
||||||
default:
|
|
||||||
http.NotFound(w, r)
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
inst := &Instance{
|
|
||||||
AccountID: "42",
|
|
||||||
CSRF: "test-csrf",
|
|
||||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
|
||||||
client: server.Client(),
|
|
||||||
healthy: true,
|
|
||||||
lastUsed: time.Now(),
|
|
||||||
}
|
|
||||||
inst.SetModelMappingReady(true)
|
|
||||||
fallback := &recordingUpstream{}
|
|
||||||
pool := &Pool{
|
|
||||||
config: Config{ReplicasPerAccount: 1},
|
|
||||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
|
||||||
}
|
|
||||||
upstream := NewLSPoolUpstream(pool, fallback)
|
|
||||||
|
|
||||||
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
|
||||||
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
|
|
||||||
require.NoError(t, err)
|
|
||||||
req1.Header.Set("Authorization", "Bearer downstream-a")
|
|
||||||
|
|
||||||
resp1, err := upstream.Do(req1, "", 42, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
body1, err := io.ReadAll(resp1.Body)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Contains(t, string(body1), `"text":"hello from ls"`)
|
|
||||||
|
|
||||||
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are different"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
|
|
||||||
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
|
|
||||||
require.NoError(t, err)
|
|
||||||
req2.Header.Set("Authorization", "Bearer downstream-a")
|
|
||||||
|
|
||||||
resp2, err := upstream.Do(req2, "", 42, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
body2, err := io.ReadAll(resp2.Body)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.Equal(t, "ok", string(body2))
|
|
||||||
require.Equal(t, 1, fallback.doCalls)
|
|
||||||
require.Equal(t, int32(1), startCalls.Load())
|
|
||||||
require.Equal(t, int32(1), sendCalls.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestJSParityRouteErrorsWhenModelMappingPending(t *testing.T) {
|
|
||||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
|
||||||
|
|
||||||
var startCalls atomic.Int32
|
|
||||||
var sendCalls atomic.Int32
|
|
||||||
|
|
||||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
switch {
|
|
||||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
|
||||||
startCalls.Add(1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
|
||||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
|
||||||
sendCalls.Add(1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
|
||||||
default:
|
|
||||||
http.NotFound(w, r)
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
inst := &Instance{
|
|
||||||
AccountID: "42",
|
|
||||||
CSRF: "test-csrf",
|
|
||||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
|
||||||
client: server.Client(),
|
|
||||||
healthy: true,
|
|
||||||
lastUsed: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
fallback := &recordingUpstream{}
|
|
||||||
pool := &Pool{
|
|
||||||
config: Config{ReplicasPerAccount: 1},
|
|
||||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
|
||||||
}
|
|
||||||
upstream := NewLSPoolUpstream(pool, fallback)
|
|
||||||
|
|
||||||
reqBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
|
||||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(reqBody))
|
|
||||||
require.NoError(t, err)
|
|
||||||
req.Header.Set("Authorization", "Bearer downstream-a")
|
|
||||||
|
|
||||||
resp, err := upstream.Do(req, "", 42, 1)
|
|
||||||
require.Nil(t, resp)
|
|
||||||
require.ErrorIs(t, err, errLSModelMapPending)
|
|
||||||
require.Equal(t, int32(0), startCalls.Load())
|
|
||||||
require.Equal(t, int32(0), sendCalls.Load())
|
|
||||||
require.Equal(t, 0, fallback.doCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
// recordingUpstreamWithCallback extends the base recordingUpstream with a callback
|
|
||||||
type recordingUpstreamWithCallback struct {
|
|
||||||
recordingUpstream
|
|
||||||
onDo func(req *http.Request)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *recordingUpstreamWithCallback) Do(req *http.Request, proxyURL string, accountID int64, c int) (*http.Response, error) {
|
|
||||||
if r.onDo != nil {
|
|
||||||
r.onDo(req)
|
|
||||||
}
|
|
||||||
return r.recordingUpstream.Do(req, proxyURL, accountID, c)
|
|
||||||
}
|
|
||||||
@ -1,920 +0,0 @@
|
|||||||
// Package lspool provides a mock Extension Server that the LS binary connects
|
|
||||||
// to at startup. The real IDE's extension.js runs a ConnectRPC HTTP/1.1 server
|
|
||||||
// using connectNodeAdapter. We replicate that protocol here.
|
|
||||||
//
|
|
||||||
// Protocol details (from extension.js source):
|
|
||||||
// - Transport: HTTP/1.1 on 127.0.0.1 (no TLS)
|
|
||||||
// - Auth: x-codeium-csrf-token header on every request
|
|
||||||
// - Unary request Content-Type: application/proto (binary protobuf, no envelope)
|
|
||||||
// OR application/connect+proto (with 5-byte envelope)
|
|
||||||
// - Unary response Content-Type: application/proto (raw binary protobuf, no envelope)
|
|
||||||
// - Stream request Content-Type: application/connect+proto (with 5-byte envelope)
|
|
||||||
// - Stream response Content-Type: application/connect+proto (envelope-framed messages)
|
|
||||||
//
|
|
||||||
// The LS sends requests with content-type "application/connect+proto" for BOTH
|
|
||||||
// unary and streaming RPCs. ConnectRPC's content-type regex:
|
|
||||||
//
|
|
||||||
// /^application\/(connect\+)?(?:(json)(?:; ?charset=utf-?8)?|(proto))$/i
|
|
||||||
//
|
|
||||||
// If "connect+" prefix is present → stream mode; otherwise → unary mode.
|
|
||||||
// However the LS Go client uses the Connect protocol client which always sends
|
|
||||||
// "application/proto" for unary and "application/connect+proto" for streaming.
|
|
||||||
//
|
|
||||||
// We detect the RPC kind from the URL path and respond accordingly.
|
|
||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log/slog"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Proto helpers — hand-encode minimal proto messages so we don't
|
|
||||||
// need to import the full generated proto package.
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
// encodeProtoString writes a proto string field (wire type 2) to a byte slice.
|
|
||||||
func encodeProtoString(fieldNum int, val string) []byte {
|
|
||||||
tag := encodeVarint(uint64(fieldNum<<3 | 2))
|
|
||||||
length := encodeVarint(uint64(len(val)))
|
|
||||||
out := make([]byte, 0, len(tag)+len(length)+len(val))
|
|
||||||
out = append(out, tag...)
|
|
||||||
out = append(out, length...)
|
|
||||||
out = append(out, []byte(val)...)
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// encodeProtoBytes writes a proto bytes/message field (wire type 2).
|
|
||||||
func encodeProtoBytes(fieldNum int, val []byte) []byte {
|
|
||||||
tag := encodeVarint(uint64(fieldNum<<3 | 2))
|
|
||||||
length := encodeVarint(uint64(len(val)))
|
|
||||||
out := make([]byte, 0, len(tag)+len(length)+len(val))
|
|
||||||
out = append(out, tag...)
|
|
||||||
out = append(out, length...)
|
|
||||||
out = append(out, val...)
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// encodeProtoVarint writes a proto varint field (wire type 0).
|
|
||||||
func encodeProtoVarint(fieldNum int, val uint64) []byte {
|
|
||||||
tag := encodeVarint(uint64(fieldNum<<3 | 0))
|
|
||||||
v := encodeVarint(val)
|
|
||||||
out := make([]byte, 0, len(tag)+len(v))
|
|
||||||
out = append(out, tag...)
|
|
||||||
out = append(out, v...)
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// encodeProtoBool writes a proto bool field.
|
|
||||||
func encodeProtoBool(fieldNum int, val bool) []byte {
|
|
||||||
v := uint64(0)
|
|
||||||
if val {
|
|
||||||
v = 1
|
|
||||||
}
|
|
||||||
return encodeProtoVarint(fieldNum, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeVarint(v uint64) []byte {
|
|
||||||
buf := make([]byte, binary.MaxVarintLen64)
|
|
||||||
n := binary.PutUvarint(buf, v)
|
|
||||||
return buf[:n]
|
|
||||||
}
|
|
||||||
|
|
||||||
// decodeProtoString extracts a string field from raw proto bytes.
|
|
||||||
func decodeProtoString(data []byte, targetField int) string {
|
|
||||||
i := 0
|
|
||||||
for i < len(data) {
|
|
||||||
if i >= len(data) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
tag, n := binary.Uvarint(data[i:])
|
|
||||||
if n <= 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
i += n
|
|
||||||
fieldNum := int(tag >> 3)
|
|
||||||
wireType := tag & 0x7
|
|
||||||
|
|
||||||
switch wireType {
|
|
||||||
case 0: // varint
|
|
||||||
_, n = binary.Uvarint(data[i:])
|
|
||||||
if n <= 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
i += n
|
|
||||||
case 2: // length-delimited
|
|
||||||
length, n := binary.Uvarint(data[i:])
|
|
||||||
if n <= 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
i += n
|
|
||||||
if fieldNum == targetField {
|
|
||||||
end := i + int(length)
|
|
||||||
if end > len(data) {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return string(data[i:end])
|
|
||||||
}
|
|
||||||
i += int(length)
|
|
||||||
case 1: // 64-bit
|
|
||||||
i += 8
|
|
||||||
case 5: // 32-bit
|
|
||||||
i += 4
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// ConnectRPC envelope helpers
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
// connectEnvelope wraps a proto payload in a ConnectRPC streaming envelope:
|
|
||||||
// 1 byte flags + 4 byte big-endian length + payload
|
|
||||||
func connectEnvelope(flags byte, payload []byte) []byte {
|
|
||||||
frame := make([]byte, 5+len(payload))
|
|
||||||
frame[0] = flags
|
|
||||||
binary.BigEndian.PutUint32(frame[1:5], uint32(len(payload)))
|
|
||||||
copy(frame[5:], payload)
|
|
||||||
return frame
|
|
||||||
}
|
|
||||||
|
|
||||||
// connectEndOfStream returns the end-of-stream trailer frame for ConnectRPC.
|
|
||||||
// flags=0x02 signals end of stream. The payload is a JSON object with empty metadata.
|
|
||||||
func connectEndOfStream() []byte {
|
|
||||||
trailer := []byte("{}")
|
|
||||||
return connectEnvelope(0x02, trailer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// unwrapConnectEnvelope strips the 5-byte envelope header from a ConnectRPC message.
|
|
||||||
// Returns the raw proto payload. If the input is shorter than 5 bytes, returns as-is.
|
|
||||||
func unwrapConnectEnvelope(body []byte) []byte {
|
|
||||||
if len(body) < 5 {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
// Check if it looks like an envelope: first byte should be 0x00 or 0x01
|
|
||||||
if body[0] > 0x02 {
|
|
||||||
return body // Not envelope-framed, return raw
|
|
||||||
}
|
|
||||||
plen := binary.BigEndian.Uint32(body[1:5])
|
|
||||||
if int(plen)+5 > len(body) {
|
|
||||||
return body // Length mismatch, return raw
|
|
||||||
}
|
|
||||||
return body[5 : 5+plen]
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// OAuthTokenInfo proto builder
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
// buildOAuthTokenInfoBinary creates binary-encoded OAuthTokenInfo proto.
|
|
||||||
//
|
|
||||||
// message OAuthTokenInfo {
|
|
||||||
// string access_token = 1;
|
|
||||||
// string token_type = 2;
|
|
||||||
// string refresh_token = 3;
|
|
||||||
// google.protobuf.Timestamp expiry = 4;
|
|
||||||
// bool is_gcp_tos = 6;
|
|
||||||
// }
|
|
||||||
func buildOAuthTokenInfoBinary(accessToken, refreshToken string, expiresAt time.Time) []byte {
|
|
||||||
var buf []byte
|
|
||||||
buf = append(buf, encodeProtoString(1, accessToken)...)
|
|
||||||
buf = append(buf, encodeProtoString(2, "Bearer")...)
|
|
||||||
if refreshToken != "" {
|
|
||||||
buf = append(buf, encodeProtoString(3, refreshToken)...)
|
|
||||||
}
|
|
||||||
// Use real expiry if provided, otherwise default to 1 hour from now
|
|
||||||
expiry := expiresAt
|
|
||||||
if expiry.IsZero() {
|
|
||||||
expiry = time.Now().Add(1 * time.Hour)
|
|
||||||
}
|
|
||||||
ts := ×tamppb.Timestamp{
|
|
||||||
Seconds: expiry.Unix(),
|
|
||||||
}
|
|
||||||
tsBytes, _ := proto.Marshal(ts)
|
|
||||||
buf = append(buf, encodeProtoBytes(4, tsBytes)...)
|
|
||||||
buf = append(buf, encodeProtoBool(6, true)...)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildUSSTopicWithOAuth creates a USS Topic proto with the OAuth token.
|
|
||||||
//
|
|
||||||
// message Topic { map<string, Row> data = 1; }
|
|
||||||
// message Row { string value = 1; int64 e_tag = 2; }
|
|
||||||
//
|
|
||||||
// The key in the map is "oauthTokenInfoSentinelKey" and the Row.value is
|
|
||||||
// base64(toBinary(OAuthTokenInfo)).
|
|
||||||
func buildUSSTopicWithOAuth(accessToken, refreshToken string, expiresAt time.Time) []byte {
|
|
||||||
tokenBin := buildOAuthTokenInfoBinary(accessToken, refreshToken, expiresAt)
|
|
||||||
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
|
|
||||||
|
|
||||||
// Row: value=tokenB64 (field 1), e_tag=1 (field 2)
|
|
||||||
var row []byte
|
|
||||||
row = append(row, encodeProtoString(1, tokenB64)...)
|
|
||||||
row = append(row, encodeProtoVarint(2, 1)...)
|
|
||||||
|
|
||||||
// Map entry: key="oauthTokenInfoSentinelKey" (field 1), value=row (field 2)
|
|
||||||
var entry []byte
|
|
||||||
entry = append(entry, encodeProtoString(1, "oauthTokenInfoSentinelKey")...)
|
|
||||||
entry = append(entry, encodeProtoBytes(2, row)...)
|
|
||||||
|
|
||||||
// Topic: data map entries use field 1
|
|
||||||
var topic []byte
|
|
||||||
topic = append(topic, encodeProtoBytes(1, entry)...)
|
|
||||||
|
|
||||||
return topic
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildPrimitiveBoolBinary(val bool) []byte {
|
|
||||||
// Primitive.bool_value is field 13 in the proto definition
|
|
||||||
return encodeProtoBool(13, val)
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildPrimitiveInt32Binary(val int32) []byte {
|
|
||||||
// Primitive.int32_value is field 3 in the proto definition
|
|
||||||
return encodeProtoVarint(3, uint64(uint32(val)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeUSSBinaryValue(value []byte) string {
|
|
||||||
return base64.StdEncoding.EncodeToString(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeUSSPrimitiveBoolValue(val bool) string {
|
|
||||||
return encodeUSSBinaryValue(buildPrimitiveBoolBinary(val))
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeUSSPrimitiveInt32Value(val int32) string {
|
|
||||||
return encodeUSSBinaryValue(buildPrimitiveInt32Binary(val))
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildUSSTopicRow(key string, value string) []byte {
|
|
||||||
row := buildUSSRowBinary(value)
|
|
||||||
|
|
||||||
var entry []byte
|
|
||||||
entry = append(entry, encodeProtoString(1, key)...)
|
|
||||||
entry = append(entry, encodeProtoBytes(2, row)...)
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildUSSRowBinary(value string) []byte {
|
|
||||||
var row []byte
|
|
||||||
row = append(row, encodeProtoString(1, value)...)
|
|
||||||
row = append(row, encodeProtoVarint(2, 1)...)
|
|
||||||
return row
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildUSSTopicWithModelCredits(info *ModelCreditsInfo) []byte {
|
|
||||||
if info == nil {
|
|
||||||
info = &ModelCreditsInfo{}
|
|
||||||
}
|
|
||||||
|
|
||||||
minimum := defaultMinimumCreditAmountForUsage
|
|
||||||
if info.MinimumCreditAmountForUsage != nil {
|
|
||||||
minimum = *info.MinimumCreditAmountForUsage
|
|
||||||
}
|
|
||||||
|
|
||||||
entries := make([][]byte, 0, 3)
|
|
||||||
entries = append(entries, buildUSSTopicRow(
|
|
||||||
useAICreditsSentinelKey,
|
|
||||||
encodeUSSPrimitiveBoolValue(info.UseAICredits),
|
|
||||||
))
|
|
||||||
// JS protocol: useAICreditsSentinelKey carries the toggle state.
|
|
||||||
// availableCreditsSentinelKey is only present when credits are enabled.
|
|
||||||
if info.UseAICredits {
|
|
||||||
credits := int32(9999)
|
|
||||||
if info.AvailableCredits != nil {
|
|
||||||
credits = *info.AvailableCredits
|
|
||||||
}
|
|
||||||
entries = append(entries, buildUSSTopicRow(availableCreditsSentinelKey, encodeUSSPrimitiveInt32Value(credits)))
|
|
||||||
}
|
|
||||||
entries = append(entries, buildUSSTopicRow(minimumCreditAmountForUsageKey, encodeUSSPrimitiveInt32Value(minimum)))
|
|
||||||
|
|
||||||
var topic []byte
|
|
||||||
for _, entry := range entries {
|
|
||||||
topic = append(topic, encodeProtoBytes(1, entry)...)
|
|
||||||
}
|
|
||||||
return topic
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildEmptyTopic returns an empty USS Topic proto (for non-oauth topics).
|
|
||||||
func buildEmptyTopic() []byte {
|
|
||||||
return []byte{} // Empty message = no map entries
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// UnifiedStateSyncUpdate builder
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
// buildInitialStateUpdate creates a UnifiedStateSyncUpdate with initial_state set.
|
|
||||||
//
|
|
||||||
// message UnifiedStateSyncUpdate {
|
|
||||||
// oneof update_type {
|
|
||||||
// Topic initial_state = 1;
|
|
||||||
// AppliedUpdate applied_update = 2;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
func buildInitialStateUpdate(topicData []byte) []byte {
|
|
||||||
return encodeProtoBytes(1, topicData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildAppliedUpdate(key string, row []byte) []byte {
|
|
||||||
var applied []byte
|
|
||||||
applied = append(applied, encodeProtoString(1, key)...)
|
|
||||||
if len(row) > 0 {
|
|
||||||
applied = append(applied, encodeProtoBytes(2, row)...)
|
|
||||||
}
|
|
||||||
return encodeProtoBytes(2, applied)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// MockExtensionServer
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
// MockExtensionServer provides a ConnectRPC-compatible HTTP server that the
|
|
||||||
// Language Server binary connects to. It implements just enough of the
|
|
||||||
// ExtensionServerService to keep the LS operational.
|
|
||||||
type MockExtensionServer struct {
|
|
||||||
listener net.Listener
|
|
||||||
server *http.Server
|
|
||||||
port int
|
|
||||||
csrf string
|
|
||||||
mu sync.RWMutex
|
|
||||||
tokens map[string]*TokenInfo // account_id -> token info
|
|
||||||
credits map[string]*ModelCreditsInfo // account_id -> model credits info
|
|
||||||
subscribers map[string]map[int]*stateSubscriber
|
|
||||||
nextSubID int
|
|
||||||
lastAccountID string
|
|
||||||
logger *slog.Logger
|
|
||||||
|
|
||||||
// Trajectory callback — when LS pushes trajectory updates, we forward them
|
|
||||||
onTrajectoryUpdate func(topic, key string, data []byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenInfo holds OAuth token details for an account.
|
|
||||||
type TokenInfo struct {
|
|
||||||
AccessToken string
|
|
||||||
RefreshToken string
|
|
||||||
ExpiresAt time.Time // zero value means unknown; defaults to now+1h
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModelCreditsInfo mirrors the JS uss-modelCredits topic state.
|
|
||||||
type ModelCreditsInfo struct {
|
|
||||||
UseAICredits bool
|
|
||||||
AvailableCredits *int32
|
|
||||||
MinimumCreditAmountForUsage *int32
|
|
||||||
}
|
|
||||||
|
|
||||||
type stateSubscriber struct {
|
|
||||||
id int
|
|
||||||
accountID string
|
|
||||||
topic string
|
|
||||||
updates chan []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
useAICreditsSentinelKey = "useAICreditsSentinelKey"
|
|
||||||
availableCreditsSentinelKey = "availableCreditsSentinelKey"
|
|
||||||
minimumCreditAmountForUsageKey = "minimumCreditAmountForUsageKey"
|
|
||||||
defaultMinimumCreditAmountForUsage = int32(50)
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewMockExtensionServer creates a mock extension server with proper ConnectRPC handling.
|
|
||||||
func NewMockExtensionServer(csrf string) (*MockExtensionServer, error) {
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("listen: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m := &MockExtensionServer{
|
|
||||||
listener: listener,
|
|
||||||
port: listener.Addr().(*net.TCPAddr).Port,
|
|
||||||
csrf: csrf,
|
|
||||||
tokens: make(map[string]*TokenInfo),
|
|
||||||
credits: make(map[string]*ModelCreditsInfo),
|
|
||||||
subscribers: make(map[string]map[int]*stateSubscriber),
|
|
||||||
logger: slog.Default().With("component", "mock-ext-server"),
|
|
||||||
}
|
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
|
||||||
extService := "/exa.extension_server_pb.ExtensionServerService/"
|
|
||||||
|
|
||||||
// Register all RPCs the LS calls on the Extension Server.
|
|
||||||
// Unary RPCs — return application/proto
|
|
||||||
mux.HandleFunc(extService+"LanguageServerStarted", m.handleUnary(m.onLanguageServerStarted))
|
|
||||||
mux.HandleFunc(extService+"Heartbeat", m.handleUnary(m.onHeartbeat))
|
|
||||||
mux.HandleFunc(extService+"GetSecretValue", m.handleUnary(m.onGetSecretValue))
|
|
||||||
mux.HandleFunc(extService+"StoreSecretValue", m.handleUnary(m.onStoreSecretValue))
|
|
||||||
mux.HandleFunc(extService+"IsAgentManagerEnabled", m.handleUnary(m.onIsAgentManagerEnabled))
|
|
||||||
mux.HandleFunc(extService+"PushUnifiedStateSyncUpdate", m.handleUnary(m.onPushUnifiedStateSyncUpdate))
|
|
||||||
mux.HandleFunc(extService+"RecordError", m.handleUnary(m.onRecordError))
|
|
||||||
mux.HandleFunc(extService+"LogEvent", m.handleUnary(m.onLogEvent))
|
|
||||||
mux.HandleFunc(extService+"UpdateCascadeTrajectorySummaries", m.handleUnary(m.onUpdateTrajectorySummaries))
|
|
||||||
mux.HandleFunc(extService+"BroadcastConversationDeletion", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"WriteCascadeEdit", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"OpenDiffZones", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"HandleAsyncPostMessage", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"OpenFilePointer", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"OpenVirtualFile", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"SaveDocument", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"RestartUserStatusUpdater", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"FocusIDEWindow", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"SmartFocusConversation", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"RunExtensionCode", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"UpdateDetailedViewWithCascadeInput", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"FindAllReferences", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"GetDefinition", m.handleUnary(m.onDefault))
|
|
||||||
mux.HandleFunc(extService+"GetLintErrors", m.handleUnary(m.onDefault))
|
|
||||||
|
|
||||||
// Server-streaming RPCs — return application/connect+proto
|
|
||||||
mux.HandleFunc(extService+"SubscribeToUnifiedStateSyncTopic", m.handleStream(m.onSubscribeStateSyncTopic))
|
|
||||||
mux.HandleFunc(extService+"ExecuteCommand", m.handleStream(m.onExecuteCommand))
|
|
||||||
|
|
||||||
// Catch-all for any unregistered RPCs
|
|
||||||
mux.HandleFunc("/", m.handleCatchAll)
|
|
||||||
|
|
||||||
m.server = &http.Server{Handler: mux}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := m.server.Serve(listener); err != http.ErrServerClosed {
|
|
||||||
m.logger.Error("extension server error", "err", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
m.logger.Info("mock extension server started", "port", m.port, "csrf_len", len(csrf))
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Port returns the listening port.
|
|
||||||
func (m *MockExtensionServer) Port() int {
|
|
||||||
return m.port
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetToken sets the OAuth token for an account.
|
|
||||||
func (m *MockExtensionServer) SetToken(accountID string, info *TokenInfo) {
|
|
||||||
m.mu.Lock()
|
|
||||||
m.tokens[accountID] = info
|
|
||||||
m.lastAccountID = accountID
|
|
||||||
subscribers := m.snapshotSubscribersLocked("uss-oauth", accountID)
|
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
if info == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tokenBin := buildOAuthTokenInfoBinary(info.AccessToken, info.RefreshToken, info.ExpiresAt)
|
|
||||||
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
|
|
||||||
m.publishTopicUpdate(subscribers, buildAppliedUpdate("oauthTokenInfoSentinelKey", buildUSSRowBinary(tokenB64)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetModelCredits sets the uss-modelCredits state for an account.
|
|
||||||
func (m *MockExtensionServer) SetModelCredits(accountID string, info *ModelCreditsInfo) {
|
|
||||||
if info == nil {
|
|
||||||
info = &ModelCreditsInfo{}
|
|
||||||
}
|
|
||||||
copyInfo := *info
|
|
||||||
m.mu.Lock()
|
|
||||||
m.credits[accountID] = ©Info
|
|
||||||
m.lastAccountID = accountID
|
|
||||||
subscribers := m.snapshotSubscribersLocked("uss-modelCredits", accountID)
|
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
m.publishTopicUpdate(subscribers, buildModelCreditsAppliedUpdates(©Info)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTrajectoryCallback registers a callback for when the LS pushes trajectory data.
|
|
||||||
func (m *MockExtensionServer) SetTrajectoryCallback(fn func(topic, key string, data []byte)) {
|
|
||||||
m.onTrajectoryUpdate = fn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) currentTokenLocked() *TokenInfo {
|
|
||||||
if m.lastAccountID != "" {
|
|
||||||
if info := m.tokens[m.lastAccountID]; info != nil {
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, info := range m.tokens {
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) currentModelCreditsLocked() *ModelCreditsInfo {
|
|
||||||
if m.lastAccountID != "" {
|
|
||||||
if info := m.credits[m.lastAccountID]; info != nil {
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, info := range m.credits {
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) tokenForAccountLocked(accountID string) *TokenInfo {
|
|
||||||
if accountID != "" {
|
|
||||||
if info := m.tokens[accountID]; info != nil {
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return m.currentTokenLocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) creditsForAccountLocked(accountID string) *ModelCreditsInfo {
|
|
||||||
if accountID != "" {
|
|
||||||
if info := m.credits[accountID]; info != nil {
|
|
||||||
return info
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return m.currentModelCreditsLocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) snapshotSubscribersLocked(topic, accountID string) []*stateSubscriber {
|
|
||||||
topicSubs := m.subscribers[topic]
|
|
||||||
if len(topicSubs) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
out := make([]*stateSubscriber, 0, len(topicSubs))
|
|
||||||
for _, sub := range topicSubs {
|
|
||||||
if sub == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if accountID != "" && sub.accountID != "" && sub.accountID != accountID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
out = append(out, sub)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) publishTopicUpdate(subscribers []*stateSubscriber, updates ...[]byte) {
|
|
||||||
for _, sub := range subscribers {
|
|
||||||
if sub == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, update := range updates {
|
|
||||||
if len(update) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
payload := append([]byte(nil), update...)
|
|
||||||
select {
|
|
||||||
case sub.updates <- payload:
|
|
||||||
default:
|
|
||||||
m.logger.Warn("dropping USS update", "topic", sub.topic, "account", sub.accountID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildModelCreditsAppliedUpdates(info *ModelCreditsInfo) [][]byte {
|
|
||||||
if info == nil {
|
|
||||||
info = &ModelCreditsInfo{}
|
|
||||||
}
|
|
||||||
minimum := defaultMinimumCreditAmountForUsage
|
|
||||||
if info.MinimumCreditAmountForUsage != nil {
|
|
||||||
minimum = *info.MinimumCreditAmountForUsage
|
|
||||||
}
|
|
||||||
|
|
||||||
updates := make([][]byte, 0, 3)
|
|
||||||
updates = append(updates, buildAppliedUpdate(
|
|
||||||
useAICreditsSentinelKey,
|
|
||||||
buildUSSRowBinary(encodeUSSPrimitiveBoolValue(info.UseAICredits)),
|
|
||||||
))
|
|
||||||
|
|
||||||
if info.UseAICredits {
|
|
||||||
credits := int32(9999)
|
|
||||||
if info.AvailableCredits != nil {
|
|
||||||
credits = *info.AvailableCredits
|
|
||||||
}
|
|
||||||
updates = append(updates, buildAppliedUpdate(
|
|
||||||
availableCreditsSentinelKey,
|
|
||||||
buildUSSRowBinary(encodeUSSPrimitiveInt32Value(credits)),
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
updates = append(updates, buildAppliedUpdate(availableCreditsSentinelKey, nil))
|
|
||||||
}
|
|
||||||
updates = append(updates, buildAppliedUpdate(
|
|
||||||
minimumCreditAmountForUsageKey,
|
|
||||||
buildUSSRowBinary(encodeUSSPrimitiveInt32Value(minimum)),
|
|
||||||
))
|
|
||||||
|
|
||||||
return updates
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close shuts down the server.
|
|
||||||
func (m *MockExtensionServer) Close() error {
|
|
||||||
return m.server.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Middleware
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
type unaryHandler func(body []byte) []byte
|
|
||||||
type streamHandler func(body []byte, w http.ResponseWriter, r *http.Request)
|
|
||||||
|
|
||||||
// handleUnary wraps a unary RPC handler with CSRF check and proper content-type.
|
|
||||||
func (m *MockExtensionServer) handleUnary(handler unaryHandler) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// CSRF check
|
|
||||||
if !m.checkCSRF(w, r) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
m.logger.Error("read body", "err", err, "path", r.URL.Path)
|
|
||||||
w.Header().Set("Content-Type", "application/proto")
|
|
||||||
w.WriteHeader(200)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// The LS might send with envelope framing (application/connect+proto)
|
|
||||||
// or without (application/proto). Detect and unwrap.
|
|
||||||
ct := r.Header.Get("Content-Type")
|
|
||||||
protoBody := body
|
|
||||||
if strings.Contains(ct, "connect+proto") && len(body) >= 5 {
|
|
||||||
protoBody = unwrapConnectEnvelope(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.logger.Debug("unary RPC", "path", r.URL.Path, "body_len", len(protoBody), "content_type", ct)
|
|
||||||
|
|
||||||
responseProto := handler(protoBody)
|
|
||||||
|
|
||||||
// Respond with proper unary ConnectRPC content-type.
|
|
||||||
// If the request used "connect+proto", the response should be "application/proto"
|
|
||||||
// for unary RPCs (ConnectRPC spec: unary uses application/proto, not connect+proto).
|
|
||||||
w.Header().Set("Content-Type", "application/proto")
|
|
||||||
w.WriteHeader(200)
|
|
||||||
if len(responseProto) > 0 {
|
|
||||||
w.Write(responseProto)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleStream wraps a server-streaming RPC handler with CSRF and content-type.
|
|
||||||
func (m *MockExtensionServer) handleStream(handler streamHandler) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !m.checkCSRF(w, r) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
m.logger.Error("read body", "err", err, "path", r.URL.Path)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unwrap envelope from request
|
|
||||||
ct := r.Header.Get("Content-Type")
|
|
||||||
if strings.Contains(ct, "connect+proto") || strings.Contains(ct, "connect+json") {
|
|
||||||
body = unwrapConnectEnvelope(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.logger.Debug("stream RPC", "path", r.URL.Path, "body_len", len(body))
|
|
||||||
|
|
||||||
// Set streaming response content-type
|
|
||||||
w.Header().Set("Content-Type", "application/connect+proto")
|
|
||||||
w.WriteHeader(200)
|
|
||||||
|
|
||||||
handler(body, w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) checkCSRF(w http.ResponseWriter, r *http.Request) bool {
|
|
||||||
token := r.Header.Get("x-codeium-csrf-token")
|
|
||||||
if m.csrf != "" && token != m.csrf {
|
|
||||||
m.logger.Warn("CSRF mismatch", "path", r.URL.Path, "got", token[:min(8, len(token))])
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
|
||||||
w.WriteHeader(403)
|
|
||||||
w.Write([]byte("Invalid CSRF token"))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func min(a, b int) int {
|
|
||||||
if a < b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Unary RPC Handlers — each receives raw proto request body,
|
|
||||||
// returns raw proto response body.
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onLanguageServerStarted(body []byte) []byte {
|
|
||||||
// LanguageServerStartedRequest has: https_port(1), http_port(2), lsp_port(3), csrf_token(4)
|
|
||||||
// We just log the ports — they're informational.
|
|
||||||
m.logger.Info("LanguageServerStarted",
|
|
||||||
"body_len", len(body))
|
|
||||||
// Return empty LanguageServerStartedResponse
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onHeartbeat(body []byte) []byte {
|
|
||||||
// Return empty HeartbeatResponse
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onGetSecretValue(body []byte) []byte {
|
|
||||||
// GetSecretValueRequest: key = field 1
|
|
||||||
key := decodeProtoString(body, 1)
|
|
||||||
m.logger.Debug("GetSecretValue", "key", key)
|
|
||||||
|
|
||||||
m.mu.RLock()
|
|
||||||
var token string
|
|
||||||
if info := m.currentTokenLocked(); info != nil {
|
|
||||||
token = info.AccessToken
|
|
||||||
}
|
|
||||||
m.mu.RUnlock()
|
|
||||||
|
|
||||||
// GetSecretValueResponse: value = field 1
|
|
||||||
if token != "" {
|
|
||||||
return encodeProtoString(1, token)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onStoreSecretValue(body []byte) []byte {
|
|
||||||
key := decodeProtoString(body, 1)
|
|
||||||
m.logger.Debug("StoreSecretValue", "key", key)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onIsAgentManagerEnabled(body []byte) []byte {
|
|
||||||
// IsAgentManagerEnabledResponse: enabled = field 1 (bool)
|
|
||||||
return encodeProtoBool(1, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onPushUnifiedStateSyncUpdate(body []byte) []byte {
|
|
||||||
// PushUnifiedStateSyncUpdateRequest: update = field 1 (UpdateRequest message)
|
|
||||||
// UpdateRequest: topic_name = field 1, applied_update = field 5, key = field 2
|
|
||||||
m.logger.Debug("PushUnifiedStateSyncUpdate", "body_len", len(body))
|
|
||||||
|
|
||||||
// Extract topic name from the embedded UpdateRequest
|
|
||||||
// The body is PushUnifiedStateSyncUpdateRequest, field 1 is the UpdateRequest
|
|
||||||
// We need to dig into the nested message to get topic_name
|
|
||||||
if m.onTrajectoryUpdate != nil {
|
|
||||||
// For now, just notify that an update was pushed
|
|
||||||
m.onTrajectoryUpdate("", "", body)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return empty PushUnifiedStateSyncUpdateResponse
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onRecordError(body []byte) []byte {
|
|
||||||
m.logger.Debug("RecordError", "body_len", len(body))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onLogEvent(body []byte) []byte {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onUpdateTrajectorySummaries(body []byte) []byte {
|
|
||||||
m.logger.Debug("UpdateCascadeTrajectorySummaries", "body_len", len(body))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onDefault(body []byte) []byte {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Streaming RPC Handlers
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onSubscribeStateSyncTopic(body []byte, w http.ResponseWriter, r *http.Request) {
|
|
||||||
// SubscribeToUnifiedStateSyncTopicRequest: topic = field 1
|
|
||||||
topic := decodeProtoString(body, 1)
|
|
||||||
m.logger.Info("SubscribeToUnifiedStateSyncTopic", "topic", topic)
|
|
||||||
|
|
||||||
flusher, ok := w.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
m.logger.Error("ResponseWriter does not support Flush")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mu.Lock()
|
|
||||||
accountID := m.lastAccountID
|
|
||||||
subID := m.nextSubID
|
|
||||||
m.nextSubID++
|
|
||||||
sub := &stateSubscriber{
|
|
||||||
id: subID,
|
|
||||||
accountID: accountID,
|
|
||||||
topic: topic,
|
|
||||||
updates: make(chan []byte, 16),
|
|
||||||
}
|
|
||||||
if m.subscribers[topic] == nil {
|
|
||||||
m.subscribers[topic] = make(map[int]*stateSubscriber)
|
|
||||||
}
|
|
||||||
m.subscribers[topic][subID] = sub
|
|
||||||
|
|
||||||
// Build initial state based on topic
|
|
||||||
var topicData []byte
|
|
||||||
switch topic {
|
|
||||||
case "uss-oauth":
|
|
||||||
tokenInfo := m.tokenForAccountLocked(accountID)
|
|
||||||
if tokenInfo != nil {
|
|
||||||
topicData = buildUSSTopicWithOAuth(tokenInfo.AccessToken, tokenInfo.RefreshToken, tokenInfo.ExpiresAt)
|
|
||||||
} else {
|
|
||||||
topicData = buildEmptyTopic()
|
|
||||||
}
|
|
||||||
case "uss-modelCredits":
|
|
||||||
creditsInfo := m.creditsForAccountLocked(accountID)
|
|
||||||
if creditsInfo != nil {
|
|
||||||
topicData = buildUSSTopicWithModelCredits(creditsInfo)
|
|
||||||
} else {
|
|
||||||
topicData = buildEmptyTopic()
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// For all other topics (browserPreferences, enterprisePreferences, etc.),
|
|
||||||
// return empty topic data.
|
|
||||||
topicData = buildEmptyTopic()
|
|
||||||
}
|
|
||||||
m.mu.Unlock()
|
|
||||||
defer func() {
|
|
||||||
m.mu.Lock()
|
|
||||||
if topicSubs := m.subscribers[topic]; topicSubs != nil {
|
|
||||||
delete(topicSubs, subID)
|
|
||||||
if len(topicSubs) == 0 {
|
|
||||||
delete(m.subscribers, topic)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.mu.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Send initial state as envelope-framed message
|
|
||||||
initialUpdate := buildInitialStateUpdate(topicData)
|
|
||||||
frame := connectEnvelope(0x00, initialUpdate)
|
|
||||||
w.Write(frame)
|
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-r.Context().Done():
|
|
||||||
m.logger.Debug("SubscribeToUnifiedStateSyncTopic stream closed", "topic", topic)
|
|
||||||
return
|
|
||||||
case update := <-sub.updates:
|
|
||||||
if len(update) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, err := w.Write(connectEnvelope(0x00, update)); err != nil {
|
|
||||||
m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) onExecuteCommand(body []byte, w http.ResponseWriter, r *http.Request) {
|
|
||||||
m.logger.Debug("ExecuteCommand (mock)", "body_len", len(body))
|
|
||||||
// Send end-of-stream immediately — we don't execute commands
|
|
||||||
flusher, ok := w.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Write(connectEndOfStream())
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Catch-all handler
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
func (m *MockExtensionServer) handleCatchAll(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !m.checkCSRF(w, r) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.logger.Debug("unhandled RPC (returning empty proto)", "path", r.URL.Path, "method", r.Method)
|
|
||||||
|
|
||||||
// Drain request body
|
|
||||||
io.ReadAll(r.Body)
|
|
||||||
|
|
||||||
// Determine if this is likely a unary or streaming request based on content-type.
|
|
||||||
ct := r.Header.Get("Content-Type")
|
|
||||||
if strings.Contains(ct, "connect+") {
|
|
||||||
// Could be streaming — respond with unary proto to be safe
|
|
||||||
// (unary Connect requests can also use connect+ prefix in some client impls)
|
|
||||||
w.Header().Set("Content-Type", "application/proto")
|
|
||||||
} else {
|
|
||||||
w.Header().Set("Content-Type", "application/proto")
|
|
||||||
}
|
|
||||||
w.WriteHeader(200)
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -1,376 +0,0 @@
|
|||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBuildLSEnvKeepsExistingSSLValues(t *testing.T) {
|
|
||||||
env := buildLSEnv([]string{
|
|
||||||
"SSL_CERT_FILE=/custom/ca.pem",
|
|
||||||
"SSL_CERT_DIR=/custom/certs",
|
|
||||||
}, "/opt/antigravity", "")
|
|
||||||
require.Contains(t, env, "ANTIGRAVITY_EDITOR_APP_ROOT=/opt/antigravity")
|
|
||||||
require.Contains(t, env, "SSL_CERT_FILE=/custom/ca.pem")
|
|
||||||
require.Contains(t, env, "SSL_CERT_DIR=/custom/certs")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildLSEnvClearsInheritedProxyWhenUnset(t *testing.T) {
|
|
||||||
env := buildLSEnv([]string{
|
|
||||||
"HTTPS_PROXY=http://old-proxy:8080",
|
|
||||||
"HTTP_PROXY=http://old-proxy:8080",
|
|
||||||
"ALL_PROXY=socks5://old-proxy:1080",
|
|
||||||
"https_proxy=http://old-proxy:8080",
|
|
||||||
"http_proxy=http://old-proxy:8080",
|
|
||||||
"all_proxy=socks5://old-proxy:1080",
|
|
||||||
}, "/opt/antigravity", "")
|
|
||||||
|
|
||||||
require.Contains(t, env, "HTTPS_PROXY=")
|
|
||||||
require.Contains(t, env, "HTTP_PROXY=")
|
|
||||||
require.Contains(t, env, "ALL_PROXY=")
|
|
||||||
require.Contains(t, env, "https_proxy=")
|
|
||||||
require.Contains(t, env, "http_proxy=")
|
|
||||||
require.Contains(t, env, "all_proxy=")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestShortAccountID(t *testing.T) {
|
|
||||||
require.Equal(t, "9", shortAccountID("9"))
|
|
||||||
require.Equal(t, "12345678", shortAccountID("12345678"))
|
|
||||||
require.Equal(t, "12345678", shortAccountID("123456789"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFrameConnectMessage(t *testing.T) {
|
|
||||||
framed := frameConnectMessage([]byte(`{"x":1}`))
|
|
||||||
require.Len(t, framed, 5+len(`{"x":1}`))
|
|
||||||
require.Equal(t, byte(0), framed[0])
|
|
||||||
require.Equal(t, uint32(len(`{"x":1}`)), binary.BigEndian.Uint32(framed[1:5]))
|
|
||||||
require.Equal(t, `{"x":1}`, string(framed[5:]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectEnvelope(t *testing.T) {
|
|
||||||
payload := []byte("hello")
|
|
||||||
env := connectEnvelope(0x00, payload)
|
|
||||||
require.Len(t, env, 5+len(payload))
|
|
||||||
require.Equal(t, byte(0x00), env[0])
|
|
||||||
require.Equal(t, uint32(5), binary.BigEndian.Uint32(env[1:5]))
|
|
||||||
require.Equal(t, "hello", string(env[5:]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnwrapConnectEnvelope(t *testing.T) {
|
|
||||||
payload := []byte("test data")
|
|
||||||
env := connectEnvelope(0x00, payload)
|
|
||||||
unwrapped := unwrapConnectEnvelope(env)
|
|
||||||
require.Equal(t, payload, unwrapped)
|
|
||||||
short := []byte{1, 2}
|
|
||||||
require.Equal(t, short, unwrapConnectEnvelope(short))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractPromptAndModel(t *testing.T) {
|
|
||||||
body := `{"model":"gemini-2.5-pro","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hello world"}]}]}}`
|
|
||||||
prompt, model := extractPromptAndModel([]byte(body))
|
|
||||||
require.Equal(t, "hello world", prompt)
|
|
||||||
require.Equal(t, "gemini-2.5-pro", model)
|
|
||||||
|
|
||||||
body2 := `{"contents":[{"role":"user","parts":[{"text":"test prompt"}]}]}`
|
|
||||||
prompt2, _ := extractPromptAndModel([]byte(body2))
|
|
||||||
require.Equal(t, "test prompt", prompt2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveModelEnum(t *testing.T) {
|
|
||||||
// Without dynamic mapping loaded, should return fallback (312 = gemini-2.5-flash)
|
|
||||||
require.True(t, resolveModelEnum("gemini-2.5-flash") > 0)
|
|
||||||
require.True(t, resolveModelEnum("models/gemini-2.5-flash") > 0)
|
|
||||||
require.True(t, resolveModelEnum("claude-sonnet-4-6") > 0)
|
|
||||||
require.True(t, resolveModelEnum("unknown-model") > 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildCascadeConfigIncludesRequestedModel(t *testing.T) {
|
|
||||||
cfg := buildCascadeConfig("models/gemini-2.5-flash")
|
|
||||||
require.NotNil(t, cfg)
|
|
||||||
|
|
||||||
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
require.NotEmpty(t, requestedModel["model"])
|
|
||||||
require.Len(t, plannerConfig, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildCascadeConfigClaudeIncludesRequestedModel(t *testing.T) {
|
|
||||||
cfg := buildCascadeConfig("claude-sonnet-4-6")
|
|
||||||
require.NotNil(t, cfg)
|
|
||||||
|
|
||||||
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
require.NotEmpty(t, requestedModel["model"])
|
|
||||||
require.Len(t, plannerConfig, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoNonStreamGeneratePassesThrough(t *testing.T) {
|
|
||||||
fallback := &recordingUpstream{}
|
|
||||||
upstream := NewLSPoolUpstream(&Pool{}, fallback)
|
|
||||||
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", bytes.NewReader([]byte(`{}`)))
|
|
||||||
resp, err := upstream.Do(req, "", 1, 1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
require.Equal(t, 1, fallback.doCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractPlannerResponseText(t *testing.T) {
|
|
||||||
resp := `{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[
|
|
||||||
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"},
|
|
||||||
{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE",
|
|
||||||
"plannerResponse":{"response":"Hello world"}}
|
|
||||||
]}}`
|
|
||||||
text, generating, status := extractPlannerResponseText([]byte(resp))
|
|
||||||
require.Equal(t, "Hello world", text)
|
|
||||||
require.False(t, generating)
|
|
||||||
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", status)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractPlannerResponseState_ErrorDetails(t *testing.T) {
|
|
||||||
resp := `{
|
|
||||||
"status":"CASCADE_RUN_STATUS_IDLE",
|
|
||||||
"trajectory":{
|
|
||||||
"steps":[
|
|
||||||
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"}
|
|
||||||
],
|
|
||||||
"executorMetadata":{
|
|
||||||
"terminationReason":"ERROR",
|
|
||||||
"errorDetails":{
|
|
||||||
"errorCode":429,
|
|
||||||
"shortError":"Model quota reached",
|
|
||||||
"details":"You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
|
|
||||||
state := extractPlannerResponseState([]byte(resp))
|
|
||||||
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", state.Status)
|
|
||||||
require.False(t, state.Generating)
|
|
||||||
require.Empty(t, state.Text)
|
|
||||||
require.Contains(t, state.ErrorMessage, "Model quota reached")
|
|
||||||
require.Contains(t, state.ErrorMessage, "quota will reset after")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildGeminiSSEChunk(t *testing.T) {
|
|
||||||
sse := buildGeminiSSEChunk("hello")
|
|
||||||
require.Contains(t, sse, "data: ")
|
|
||||||
require.Contains(t, sse, `"text":"hello"`)
|
|
||||||
require.Contains(t, sse, `"role":"model"`)
|
|
||||||
require.True(t, strings.HasSuffix(sse, "\n\n"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestHasTools(t *testing.T) {
|
|
||||||
// Wrapped format with tools
|
|
||||||
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"functionDeclarations":[{"name":"get_weather"}]}]}}`)))
|
|
||||||
|
|
||||||
// Direct format with tools
|
|
||||||
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[{"name":"f"}]}]}`)))
|
|
||||||
|
|
||||||
// No tools
|
|
||||||
require.False(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`)))
|
|
||||||
|
|
||||||
// Empty tools array
|
|
||||||
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":[]}`)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCurrentLSStrategy(t *testing.T) {
|
|
||||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "js-parity")
|
|
||||||
require.Equal(t, LSStrategyJSParity, CurrentLSStrategy())
|
|
||||||
|
|
||||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "unknown")
|
|
||||||
require.Equal(t, LSStrategyDirect, CurrentLSStrategy())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsPermanentModelMappingError(t *testing.T) {
|
|
||||||
require.True(t, isPermanentModelMappingError(errors.New(`oauth2: "unauthorized_client" "Unauthorized"`)))
|
|
||||||
require.False(t, isPermanentModelMappingError(errors.New("context deadline exceeded")))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPoolSetAccountTokenClearsModelMappingUnavailable(t *testing.T) {
|
|
||||||
pool := &Pool{
|
|
||||||
instances: map[string][]*Instance{
|
|
||||||
"9": {
|
|
||||||
{AccountID: "9", Replica: 0},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
inst := pool.instances["9"][0]
|
|
||||||
inst.SetModelMappingReady(true)
|
|
||||||
inst.SetModelMappingUnavailable(`oauth2: "unauthorized_client" "Unauthorized"`)
|
|
||||||
|
|
||||||
pool.SetAccountToken("9", "ya29.new", "refresh", time.Now().Add(time.Hour))
|
|
||||||
|
|
||||||
require.False(t, inst.HasModelMappingReady())
|
|
||||||
require.False(t, inst.HasModelMappingUnavailable())
|
|
||||||
require.Empty(t, inst.ModelMappingUnavailableReason())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestShouldFallbackDirectForModelMappingUnavailable(t *testing.T) {
|
|
||||||
require.True(t, shouldFallbackDirect(fmt.Errorf("%w: oauth2 unauthorized_client", errLSModelMapDenied)))
|
|
||||||
require.False(t, shouldFallbackDirect(errLSModelMapPending))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseLSReplicaCountDefaultAndEnv(t *testing.T) {
|
|
||||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "")
|
|
||||||
require.Equal(t, 5, parseLSReplicaCount())
|
|
||||||
|
|
||||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "3")
|
|
||||||
require.Equal(t, 3, parseLSReplicaCount())
|
|
||||||
|
|
||||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "0")
|
|
||||||
require.Equal(t, 5, parseLSReplicaCount())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPoolGetUsesStickyReplicaSlot(t *testing.T) {
|
|
||||||
pool := &Pool{
|
|
||||||
config: Config{ReplicasPerAccount: 5},
|
|
||||||
instances: map[string][]*Instance{
|
|
||||||
"acc-1": {
|
|
||||||
{AccountID: "acc-1", Replica: 0, healthy: true},
|
|
||||||
{AccountID: "acc-1", Replica: 1, healthy: true},
|
|
||||||
{AccountID: "acc-1", Replica: 2, healthy: true},
|
|
||||||
{AccountID: "acc-1", Replica: 3, healthy: true},
|
|
||||||
{AccountID: "acc-1", Replica: 4, healthy: true},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
routingKey := "acc-1:user-a:session-1"
|
|
||||||
slot := replicaSlotIndex(routingKey, pool.replicaCount())
|
|
||||||
inst := pool.Get("acc-1", routingKey)
|
|
||||||
require.NotNil(t, inst)
|
|
||||||
require.Equal(t, slot, inst.Replica)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica(t *testing.T) {
|
|
||||||
busy := &Instance{AccountID: "acc-1", Replica: 0, healthy: true}
|
|
||||||
atomic.StoreInt64(&busy.inflight, 4)
|
|
||||||
idle := &Instance{AccountID: "acc-1", Replica: 1, healthy: true}
|
|
||||||
atomic.StoreInt64(&idle.inflight, 1)
|
|
||||||
|
|
||||||
pool := &Pool{
|
|
||||||
config: Config{ReplicasPerAccount: 5},
|
|
||||||
instances: map[string][]*Instance{
|
|
||||||
"acc-1": {busy, idle},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
inst := pool.Get("acc-1", "")
|
|
||||||
require.NotNil(t, inst)
|
|
||||||
require.Equal(t, 1, inst.Replica)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWaitForInstanceReadyProbesImmediately(t *testing.T) {
|
|
||||||
startedAt := time.Now()
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
attempts, err := waitForInstanceReady(ctx, 200*time.Millisecond, func(context.Context) error {
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 1, attempts)
|
|
||||||
require.Less(t, time.Since(startedAt), 100*time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWaitForInstanceReadyRetriesUntilSuccess(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
calls := 0
|
|
||||||
attempts, err := waitForInstanceReady(ctx, 10*time.Millisecond, func(context.Context) error {
|
|
||||||
calls++
|
|
||||||
if calls < 3 {
|
|
||||||
return errors.New("not ready")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 3, attempts)
|
|
||||||
require.Equal(t, 3, calls)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDecideJSParityRoute(t *testing.T) {
|
|
||||||
body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
|
||||||
parsed, err := parseGeminiRequest(body)
|
|
||||||
require.NoError(t, err)
|
|
||||||
decision := decideJSParityRoute(parsed, body)
|
|
||||||
require.True(t, decision.UseLS)
|
|
||||||
|
|
||||||
imageBody := []byte(`{"model":"gemini-2.5-flash-image","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"draw"}]}],"generationConfig":{"responseModalities":["TEXT","IMAGE"]}}}`)
|
|
||||||
parsedImage, err := parseGeminiRequest(imageBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
decisionImage := decideJSParityRoute(parsedImage, imageBody)
|
|
||||||
require.False(t, decisionImage.UseLS)
|
|
||||||
require.Contains(t, strings.ToLower(decisionImage.Reason), "image")
|
|
||||||
|
|
||||||
noSessionBody := []byte(`{"model":"gemini-2.5-flash","request":{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
|
||||||
parsedNoSession, err := parseGeminiRequest(noSessionBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.False(t, decideJSParityRoute(parsedNoSession, noSessionBody).UseLS)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserNamespacePrefersExplicitHeader(t *testing.T) {
|
|
||||||
req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
req.Header.Set(userNamespaceHeader, "tenant-a")
|
|
||||||
req.Header.Set("Authorization", "Bearer oauth-token")
|
|
||||||
|
|
||||||
nsWithExplicit := userNamespace(req)
|
|
||||||
require.NotEqual(t, "anon", nsWithExplicit)
|
|
||||||
|
|
||||||
req.Header.Del(userNamespaceHeader)
|
|
||||||
nsWithAuth := userNamespace(req)
|
|
||||||
require.NotEqual(t, "anon", nsWithAuth)
|
|
||||||
require.NotEqual(t, nsWithExplicit, nsWithAuth)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConversationPrefixEqual(t *testing.T) {
|
|
||||||
prefix := []geminiConversationTurn{
|
|
||||||
{Role: "user", Parts: []geminiConversationPart{{Kind: "text", Text: "hello"}}},
|
|
||||||
{Role: "model", Parts: []geminiConversationPart{{Kind: "text", Text: "world"}}},
|
|
||||||
}
|
|
||||||
full := append(cloneConversationTurns(prefix), geminiConversationTurn{
|
|
||||||
Role: "user",
|
|
||||||
Parts: []geminiConversationPart{{Kind: "text", Text: "follow up"}},
|
|
||||||
})
|
|
||||||
require.True(t, conversationPrefixEqual(full, prefix))
|
|
||||||
require.False(t, conversationPrefixEqual(prefix, full))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSystemTextCompatible(t *testing.T) {
|
|
||||||
require.True(t, systemTextCompatible("You are helpful", ""))
|
|
||||||
require.True(t, systemTextCompatible("You are helpful", "You are helpful"))
|
|
||||||
require.False(t, systemTextCompatible("", "You are helpful"))
|
|
||||||
require.False(t, systemTextCompatible("You are helpful", "You are different"))
|
|
||||||
}
|
|
||||||
|
|
||||||
type recordingUpstream struct {
|
|
||||||
doCalls int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *recordingUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
|
||||||
r.doCalls++
|
|
||||||
return &http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString("ok")), Header: make(http.Header), Request: req}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *recordingUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, c int, _ *tlsfingerprint.Profile) (*http.Response, error) {
|
|
||||||
return r.Do(req, proxyURL, accountID, c)
|
|
||||||
}
|
|
||||||
@ -1,268 +0,0 @@
|
|||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log/slog"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
type lsProxyBridge struct {
|
|
||||||
listener net.Listener
|
|
||||||
server *http.Server
|
|
||||||
url string
|
|
||||||
upstream string
|
|
||||||
}
|
|
||||||
|
|
||||||
type lsProxyBridgeManager struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
bridges map[string]*lsProxyBridge
|
|
||||||
logger *slog.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
var globalLSProxyBridgeManager = &lsProxyBridgeManager{
|
|
||||||
bridges: make(map[string]*lsProxyBridge),
|
|
||||||
logger: slog.Default().With("component", "lspool-proxy-bridge"),
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
lsProxyBridgeDialTimeout = 10 * time.Second
|
|
||||||
lsProxyBridgeProbeTargets = []string{
|
|
||||||
"cloudcode-pa.googleapis.com:443",
|
|
||||||
"oauthaccountmanager.googleapis.com:443",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func prepareLSProxyURL(raw string) (string, error) {
|
|
||||||
normalized, parsed, err := proxyurl.Parse(raw)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if parsed == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch strings.ToLower(parsed.Scheme) {
|
|
||||||
case "http", "https":
|
|
||||||
return normalized, nil
|
|
||||||
case "socks5", "socks5h":
|
|
||||||
return globalLSProxyBridgeManager.ensure(normalized, parsed)
|
|
||||||
default:
|
|
||||||
return "", fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *lsProxyBridgeManager) ensure(key string, upstream *url.URL) (string, error) {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
if bridge := m.bridges[key]; bridge != nil {
|
|
||||||
return bridge.url, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
bridge, err := newLSProxyBridge(upstream, m.logger)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
m.bridges[key] = bridge
|
|
||||||
return bridge.url, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *lsProxyBridgeManager) closeAll() {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
for key, bridge := range m.bridges {
|
|
||||||
if bridge != nil {
|
|
||||||
_ = bridge.server.Close()
|
|
||||||
_ = bridge.listener.Close()
|
|
||||||
}
|
|
||||||
delete(m.bridges, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeAllLSProxyBridgesForTest() {
|
|
||||||
globalLSProxyBridgeManager.closeAll()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newLSProxyBridge(upstream *url.URL, logger *slog.Logger) (*lsProxyBridge, error) {
|
|
||||||
dialer, err := proxy.FromURL(upstream, proxy.Direct)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create SOCKS dialer: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("listen LS proxy bridge: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bridge := &lsProxyBridge{
|
|
||||||
listener: listener,
|
|
||||||
url: "http://" + listener.Addr().String(),
|
|
||||||
upstream: upstream.Redacted(),
|
|
||||||
}
|
|
||||||
|
|
||||||
server := &http.Server{
|
|
||||||
Handler: http.HandlerFunc(bridge.connectHandler(dialer, logger)),
|
|
||||||
ReadHeaderTimeout: 10 * time.Second,
|
|
||||||
IdleTimeout: 2 * time.Minute,
|
|
||||||
}
|
|
||||||
bridge.server = server
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
|
||||||
logger.Error("LS proxy bridge serve failed", "upstream", bridge.upstream, "err", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
logger.Info("LS proxy bridge started", "upstream", bridge.upstream, "listen", bridge.url)
|
|
||||||
go bridge.probeConnectivity(dialer, logger)
|
|
||||||
return bridge, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *lsProxyBridge) connectHandler(dialer proxy.Dialer, logger *slog.Logger) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method != http.MethodConnect {
|
|
||||||
http.Error(w, "CONNECT only", http.StatusMethodNotAllowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetAddr := strings.TrimSpace(r.Host)
|
|
||||||
if targetAddr == "" {
|
|
||||||
targetAddr = strings.TrimSpace(r.URL.Host)
|
|
||||||
}
|
|
||||||
if targetAddr == "" {
|
|
||||||
http.Error(w, "missing target host", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, _, err := net.SplitHostPort(targetAddr); err != nil {
|
|
||||||
targetAddr = net.JoinHostPort(targetAddr, "443")
|
|
||||||
}
|
|
||||||
|
|
||||||
startedAt := time.Now()
|
|
||||||
logger.Info("LS proxy bridge CONNECT", "upstream", b.upstream, "target", targetAddr)
|
|
||||||
|
|
||||||
dialCtx, cancel := context.WithTimeout(r.Context(), lsProxyBridgeDialTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
targetConn, err := dialViaProxy(dialCtx, dialer, targetAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("LS proxy bridge dial failed",
|
|
||||||
"upstream", b.upstream,
|
|
||||||
"target", targetAddr,
|
|
||||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
|
|
||||||
"err", err)
|
|
||||||
http.Error(w, "proxy dial failed", http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Info("LS proxy bridge CONNECT established",
|
|
||||||
"upstream", b.upstream,
|
|
||||||
"target", targetAddr,
|
|
||||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
|
|
||||||
|
|
||||||
hijacker, ok := w.(http.Hijacker)
|
|
||||||
if !ok {
|
|
||||||
_ = targetConn.Close()
|
|
||||||
http.Error(w, "hijack unsupported", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
clientConn, rw, err := hijacker.Hijack()
|
|
||||||
if err != nil {
|
|
||||||
_ = targetConn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
|
|
||||||
_ = targetConn.Close()
|
|
||||||
_ = clientConn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if rw != nil && rw.Reader.Buffered() > 0 {
|
|
||||||
if _, err := io.CopyN(targetConn, rw, int64(rw.Reader.Buffered())); err != nil {
|
|
||||||
_ = targetConn.Close()
|
|
||||||
_ = clientConn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tunnelConns(clientConn, targetConn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func dialViaProxy(ctx context.Context, dialer proxy.Dialer, targetAddr string) (net.Conn, error) {
|
|
||||||
if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
|
|
||||||
return contextDialer.DialContext(ctx, "tcp", targetAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
type dialResult struct {
|
|
||||||
conn net.Conn
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
resultCh := make(chan dialResult, 1)
|
|
||||||
go func() {
|
|
||||||
conn, err := dialer.Dial("tcp", targetAddr)
|
|
||||||
resultCh <- dialResult{conn: conn, err: err}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case result := <-resultCh:
|
|
||||||
return result.conn, result.err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *lsProxyBridge) probeConnectivity(dialer proxy.Dialer, logger *slog.Logger) {
|
|
||||||
for _, targetAddr := range lsProxyBridgeProbeTargets {
|
|
||||||
startedAt := time.Now()
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), lsProxyBridgeDialTimeout)
|
|
||||||
conn, err := dialViaProxy(ctx, dialer, targetAddr)
|
|
||||||
cancel()
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("LS proxy bridge probe failed",
|
|
||||||
"upstream", b.upstream,
|
|
||||||
"target", targetAddr,
|
|
||||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
|
|
||||||
"err", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
_ = conn.Close()
|
|
||||||
logger.Info("LS proxy bridge probe succeeded",
|
|
||||||
"upstream", b.upstream,
|
|
||||||
"target", targetAddr,
|
|
||||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func tunnelConns(clientConn net.Conn, targetConn net.Conn) {
|
|
||||||
var once sync.Once
|
|
||||||
closeBoth := func() {
|
|
||||||
_ = clientConn.Close()
|
|
||||||
_ = targetConn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_, _ = io.Copy(targetConn, clientConn)
|
|
||||||
once.Do(closeBoth)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
_, _ = io.Copy(clientConn, targetConn)
|
|
||||||
once.Do(closeBoth)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func readConnectResponse(br *bufio.Reader) (*http.Response, error) {
|
|
||||||
return http.ReadResponse(br, &http.Request{Method: http.MethodConnect})
|
|
||||||
}
|
|
||||||
@ -1,193 +0,0 @@
|
|||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPrepareLSProxyURLPassesThroughHTTPProxy(t *testing.T) {
|
|
||||||
t.Cleanup(closeAllLSProxyBridgesForTest)
|
|
||||||
|
|
||||||
got, err := prepareLSProxyURL("http://proxy.example.com:8080")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "http://proxy.example.com:8080", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPrepareLSProxyURLBridgesSOCKS5ForLS(t *testing.T) {
|
|
||||||
t.Cleanup(closeAllLSProxyBridgesForTest)
|
|
||||||
|
|
||||||
targetAddr, closeTarget := startBridgeEchoServer(t)
|
|
||||||
defer closeTarget()
|
|
||||||
|
|
||||||
socksURL, closeSOCKS := startBridgeSOCKS5Server(t)
|
|
||||||
defer closeSOCKS()
|
|
||||||
|
|
||||||
bridgeURL, err := prepareLSProxyURL(socksURL)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.True(t, strings.HasPrefix(bridgeURL, "http://127.0.0.1:"))
|
|
||||||
|
|
||||||
// Same SOCKS upstream should reuse the same local bridge.
|
|
||||||
reusedURL, err := prepareLSProxyURL(socksURL)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, bridgeURL, reusedURL)
|
|
||||||
|
|
||||||
bridgeAddr := strings.TrimPrefix(bridgeURL, "http://")
|
|
||||||
conn, err := net.Dial("tcp", bridgeAddr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
_, err = fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", targetAddr, targetAddr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
reader := bufio.NewReader(conn)
|
|
||||||
resp, err := readConnectResponse(reader)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 200, resp.StatusCode)
|
|
||||||
|
|
||||||
_, err = conn.Write([]byte("ping"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
reply := make([]byte, 4)
|
|
||||||
_, err = io.ReadFull(reader, reply)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "pong", string(reply))
|
|
||||||
}
|
|
||||||
|
|
||||||
func startBridgeEchoServer(t *testing.T) (string, func()) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer close(done)
|
|
||||||
for {
|
|
||||||
conn, err := ln.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func(c net.Conn) {
|
|
||||||
defer c.Close()
|
|
||||||
buf := make([]byte, 4)
|
|
||||||
if _, err := io.ReadFull(c, buf); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if string(buf) == "ping" {
|
|
||||||
_, _ = c.Write([]byte("pong"))
|
|
||||||
}
|
|
||||||
}(conn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return ln.Addr().String(), func() {
|
|
||||||
_ = ln.Close()
|
|
||||||
<-done
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func startBridgeSOCKS5Server(t *testing.T) (string, func()) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer close(done)
|
|
||||||
for {
|
|
||||||
conn, err := ln.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go handleBridgeSOCKS5Conn(conn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return "socks5://" + ln.Addr().String(), func() {
|
|
||||||
_ = ln.Close()
|
|
||||||
<-done
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleBridgeSOCKS5Conn(conn net.Conn) {
|
|
||||||
header := make([]byte, 2)
|
|
||||||
if _, err := io.ReadFull(conn, header); err != nil {
|
|
||||||
_ = conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
methods := make([]byte, int(header[1]))
|
|
||||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
|
||||||
_ = conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, _ = conn.Write([]byte{0x05, 0x00})
|
|
||||||
|
|
||||||
reqHeader := make([]byte, 4)
|
|
||||||
if _, err := io.ReadFull(conn, reqHeader); err != nil {
|
|
||||||
_ = conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 {
|
|
||||||
_ = conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetHost, ok := readSOCKS5Addr(conn, reqHeader[3])
|
|
||||||
if !ok {
|
|
||||||
_ = conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
portBuf := make([]byte, 2)
|
|
||||||
if _, err := io.ReadFull(conn, portBuf); err != nil {
|
|
||||||
_ = conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
targetAddr := fmt.Sprintf("%s:%d", targetHost, binary.BigEndian.Uint16(portBuf))
|
|
||||||
|
|
||||||
targetConn, err := net.Dial("tcp", targetAddr)
|
|
||||||
if err != nil {
|
|
||||||
_, _ = conn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
|
||||||
_ = conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
|
||||||
tunnelConns(conn, targetConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func readSOCKS5Addr(conn net.Conn, atyp byte) (string, bool) {
|
|
||||||
switch atyp {
|
|
||||||
case 0x01:
|
|
||||||
buf := make([]byte, 4)
|
|
||||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return net.IP(buf).String(), true
|
|
||||||
case 0x03:
|
|
||||||
lenBuf := make([]byte, 1)
|
|
||||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
buf := make([]byte, int(lenBuf[0]))
|
|
||||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return string(buf), true
|
|
||||||
case 0x04:
|
|
||||||
buf := make([]byte, 16)
|
|
||||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return net.IP(buf).String(), true
|
|
||||||
default:
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,138 +0,0 @@
|
|||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
|
||||||
)
|
|
||||||
|
|
||||||
type lsLaunchPlan struct {
|
|
||||||
cmd *exec.Cmd
|
|
||||||
effectiveProxyURL string
|
|
||||||
proxyMode string
|
|
||||||
cleanup func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareLSLaunchPlan(binPath string, args []string, rawProxyURL string) (*lsLaunchPlan, error) {
|
|
||||||
normalized, parsed, err := proxyurl.Parse(rawProxyURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
plan := &lsLaunchPlan{
|
|
||||||
cmd: exec.Command(binPath, args...),
|
|
||||||
proxyMode: "direct",
|
|
||||||
}
|
|
||||||
|
|
||||||
if parsed == nil {
|
|
||||||
return plan, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch strings.ToLower(parsed.Scheme) {
|
|
||||||
case "http", "https":
|
|
||||||
plan.effectiveProxyURL = normalized
|
|
||||||
plan.proxyMode = "env-http-proxy"
|
|
||||||
return plan, nil
|
|
||||||
|
|
||||||
case "socks5", "socks5h":
|
|
||||||
if proxychainsPath, err := exec.LookPath("proxychains4"); err == nil {
|
|
||||||
cfgPath, err := writeProxychainsConfig(parsed)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
plan.cmd = exec.Command(proxychainsPath, append([]string{"-f", cfgPath, binPath}, args...)...)
|
|
||||||
plan.proxyMode = "proxychains4"
|
|
||||||
plan.cleanup = func() {
|
|
||||||
_ = os.Remove(cfgPath)
|
|
||||||
}
|
|
||||||
return plan, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
effectiveProxyURL, err := prepareLSProxyURL(normalized)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
plan.effectiveProxyURL = effectiveProxyURL
|
|
||||||
plan.proxyMode = "http-connect-bridge"
|
|
||||||
return plan, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeProxychainsConfig(proxyURL *url.URL) (string, error) {
|
|
||||||
content, err := buildProxychainsConfig(proxyURL)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
file, err := os.CreateTemp("", "sub2api-proxychains-*.conf")
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("create proxychains config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := file.WriteString(content); err != nil {
|
|
||||||
_ = file.Close()
|
|
||||||
_ = os.Remove(file.Name())
|
|
||||||
return "", fmt.Errorf("write proxychains config: %w", err)
|
|
||||||
}
|
|
||||||
if err := file.Close(); err != nil {
|
|
||||||
_ = os.Remove(file.Name())
|
|
||||||
return "", fmt.Errorf("close proxychains config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return file.Name(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildProxychainsConfig(proxyURL *url.URL) (string, error) {
|
|
||||||
if proxyURL == nil {
|
|
||||||
return "", fmt.Errorf("proxy url is nil")
|
|
||||||
}
|
|
||||||
if scheme := strings.ToLower(proxyURL.Scheme); scheme != "socks5" && scheme != "socks5h" {
|
|
||||||
return "", fmt.Errorf("proxychains only supports socks5/socks5h, got %s", proxyURL.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
host := strings.TrimSpace(proxyURL.Hostname())
|
|
||||||
port := strings.TrimSpace(proxyURL.Port())
|
|
||||||
if host == "" {
|
|
||||||
return "", fmt.Errorf("proxy host is empty")
|
|
||||||
}
|
|
||||||
if port == "" {
|
|
||||||
port = "1080"
|
|
||||||
}
|
|
||||||
|
|
||||||
username := proxyURL.User.Username()
|
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
if strings.ContainsAny(username, " \t\r\n") || strings.ContainsAny(password, " \t\r\n") {
|
|
||||||
return "", fmt.Errorf("proxychains credentials cannot contain whitespace")
|
|
||||||
}
|
|
||||||
|
|
||||||
var builder strings.Builder
|
|
||||||
builder.WriteString("strict_chain\n")
|
|
||||||
builder.WriteString("proxy_dns\n")
|
|
||||||
builder.WriteString("remote_dns_subnet 224\n")
|
|
||||||
builder.WriteString("tcp_connect_time_out 8000\n")
|
|
||||||
builder.WriteString("tcp_read_time_out 15000\n")
|
|
||||||
builder.WriteString("localnet 127.0.0.0/255.0.0.0\n")
|
|
||||||
builder.WriteString("localnet ::1/128\n")
|
|
||||||
builder.WriteString("[ProxyList]\n")
|
|
||||||
builder.WriteString("socks5 ")
|
|
||||||
builder.WriteString(host)
|
|
||||||
builder.WriteString(" ")
|
|
||||||
builder.WriteString(port)
|
|
||||||
if username != "" {
|
|
||||||
builder.WriteString(" ")
|
|
||||||
builder.WriteString(username)
|
|
||||||
if password != "" {
|
|
||||||
builder.WriteString(" ")
|
|
||||||
builder.WriteString(password)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
builder.WriteString("\n")
|
|
||||||
return builder.String(), nil
|
|
||||||
}
|
|
||||||
@ -1,31 +0,0 @@
|
|||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBuildProxychainsConfigIncludesAuthAndLocalBypass(t *testing.T) {
|
|
||||||
proxyURL, err := url.Parse("socks5h://testuser:testpass@192.0.2.1:1080")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
cfg, err := buildProxychainsConfig(proxyURL)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Contains(t, cfg, "proxy_dns\n")
|
|
||||||
require.Contains(t, cfg, "localnet 127.0.0.0/255.0.0.0\n")
|
|
||||||
require.Contains(t, cfg, "localnet ::1/128\n")
|
|
||||||
require.Contains(t, cfg, "[ProxyList]\n")
|
|
||||||
require.Contains(t, cfg, "socks5 192.0.2.1 1080 testuser testpass\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildProxychainsConfigRejectsWhitespaceCredentials(t *testing.T) {
|
|
||||||
proxyURL, err := url.Parse("socks5h://user:bad%20pass@127.0.0.1:1080")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = buildProxychainsConfig(proxyURL)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.True(t, strings.Contains(err.Error(), "whitespace"))
|
|
||||||
}
|
|
||||||
@ -1,99 +0,0 @@
|
|||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (i *Instance) callWorkerUnary(ctx context.Context, service, method, mode string, body []byte) ([]byte, error) {
|
|
||||||
endpoint, err := i.workerEndpoint("/rpc/unary", service, method, mode)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("X-Worker-Token", i.workerToken)
|
|
||||||
if mode == "json" {
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Content-Type", "application/octet-stream")
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := i.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("worker rpc %s/%s: %w", service, method, err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("worker rpc read response: %w", err)
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return respBody, fmt.Errorf("worker rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(respBody), 200))
|
|
||||||
}
|
|
||||||
return respBody, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Instance) callWorkerStream(ctx context.Context, service, method, mode string, body []byte) (*http.Response, error) {
|
|
||||||
endpoint, err := i.workerEndpoint("/rpc/stream", service, method, mode)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("X-Worker-Token", i.workerToken)
|
|
||||||
if mode == "json" {
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Content-Type", "application/octet-stream")
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := i.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("worker stream rpc %s/%s: %w", service, method, err)
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
return nil, fmt.Errorf("worker stream rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(body), 200))
|
|
||||||
}
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Instance) workerEndpoint(path, service, method, mode string) (string, error) {
|
|
||||||
base := url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: i.Address,
|
|
||||||
Path: path,
|
|
||||||
}
|
|
||||||
values := url.Values{}
|
|
||||||
values.Set("service", service)
|
|
||||||
values.Set("method", method)
|
|
||||||
values.Set("mode", mode)
|
|
||||||
if i.routingKey != "" {
|
|
||||||
values.Set("routing_key", i.routingKey)
|
|
||||||
}
|
|
||||||
base.RawQuery = values.Encode()
|
|
||||||
return base.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalWorkerJSONBody(input any) ([]byte, error) {
|
|
||||||
if input == nil {
|
|
||||||
return []byte("{}"), nil
|
|
||||||
}
|
|
||||||
body, err := json.Marshal(input)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return body, nil
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -1,680 +0,0 @@
|
|||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log/slog"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
|
||||||
"github.com/docker/docker/api/types/container"
|
|
||||||
"github.com/docker/docker/api/types/filters"
|
|
||||||
"github.com/docker/docker/api/types/network"
|
|
||||||
"github.com/docker/docker/client"
|
|
||||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
lsWorkerManagedByLabel = "managed-by"
|
|
||||||
lsWorkerManagedByValue = "sub2api"
|
|
||||||
lsWorkerAccountLabel = "account_id"
|
|
||||||
lsWorkerProxyHashLabel = "proxy_hash"
|
|
||||||
lsWorkerImageTagLabel = "image_tag"
|
|
||||||
lsWorkerControlPort = 18081
|
|
||||||
)
|
|
||||||
|
|
||||||
type workerManagerConfig struct {
|
|
||||||
Image string
|
|
||||||
Network string
|
|
||||||
DockerSocket string
|
|
||||||
IdleTTL time.Duration
|
|
||||||
MaxActive int
|
|
||||||
StartupTimeout time.Duration
|
|
||||||
RequestTimeout time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
type dockerClient interface {
|
|
||||||
ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error)
|
|
||||||
ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error)
|
|
||||||
ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
|
|
||||||
ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error)
|
|
||||||
ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
|
|
||||||
ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
type workerManager struct {
|
|
||||||
cfg workerManagerConfig
|
|
||||||
docker dockerClient
|
|
||||||
http *http.Client
|
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
workers map[string]*workerHandle
|
|
||||||
state map[string]*workerAccountState
|
|
||||||
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
logger *slog.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
type workerHandle struct {
|
|
||||||
Key string
|
|
||||||
AccountID string
|
|
||||||
ProxyURL string
|
|
||||||
ProxyHash string
|
|
||||||
ContainerID string
|
|
||||||
Container string
|
|
||||||
Address string
|
|
||||||
AuthToken string
|
|
||||||
LastUsed time.Time
|
|
||||||
LastStateSHA string
|
|
||||||
}
|
|
||||||
|
|
||||||
type workerAccountState struct {
|
|
||||||
HasToken bool `json:"has_token"`
|
|
||||||
AccessToken string `json:"access_token,omitempty"`
|
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
|
||||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
|
||||||
HasModelCredits bool `json:"has_model_credits"`
|
|
||||||
UseAICredits bool `json:"use_ai_credits"`
|
|
||||||
AvailableCredits *int32 `json:"available_credits,omitempty"`
|
|
||||||
MinimumCreditAmount *int32 `json:"minimum_credit_amount,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWorkerManagerFromConfig(cfg *config.Config) (Backend, error) {
|
|
||||||
if cfg == nil {
|
|
||||||
return nil, fmt.Errorf("config is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
managerCfg := workerManagerConfig{
|
|
||||||
Image: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Image),
|
|
||||||
Network: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Network),
|
|
||||||
DockerSocket: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.DockerSocket),
|
|
||||||
IdleTTL: cfg.Gateway.AntigravityLSWorker.IdleTTL,
|
|
||||||
MaxActive: cfg.Gateway.AntigravityLSWorker.MaxActive,
|
|
||||||
StartupTimeout: cfg.Gateway.AntigravityLSWorker.StartupTimeout,
|
|
||||||
RequestTimeout: cfg.Gateway.AntigravityLSWorker.RequestTimeout,
|
|
||||||
}
|
|
||||||
|
|
||||||
if managerCfg.Image == "" {
|
|
||||||
managerCfg.Image = "weishaw/sub2api-lsworker:latest"
|
|
||||||
}
|
|
||||||
if managerCfg.Network == "" {
|
|
||||||
managerCfg.Network = "sub2api-network"
|
|
||||||
}
|
|
||||||
if managerCfg.DockerSocket == "" {
|
|
||||||
managerCfg.DockerSocket = "unix:///var/run/docker.sock"
|
|
||||||
}
|
|
||||||
if managerCfg.IdleTTL <= 0 {
|
|
||||||
managerCfg.IdleTTL = 15 * time.Minute
|
|
||||||
}
|
|
||||||
if managerCfg.MaxActive < 1 {
|
|
||||||
managerCfg.MaxActive = 50
|
|
||||||
}
|
|
||||||
if managerCfg.StartupTimeout <= 0 {
|
|
||||||
managerCfg.StartupTimeout = 45 * time.Second
|
|
||||||
}
|
|
||||||
if managerCfg.RequestTimeout <= 0 {
|
|
||||||
managerCfg.RequestTimeout = 60 * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := []client.Opt{client.WithAPIVersionNegotiation()}
|
|
||||||
if managerCfg.DockerSocket != "" {
|
|
||||||
opts = append(opts, client.WithHost(managerCfg.DockerSocket))
|
|
||||||
} else {
|
|
||||||
opts = append(opts, client.FromEnv)
|
|
||||||
}
|
|
||||||
|
|
||||||
dockerClient, err := client.NewClientWithOpts(opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create docker client: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return newWorkerManager(managerCfg, dockerClient)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newWorkerManager(cfg workerManagerConfig, docker dockerClient) (*workerManager, error) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
mgr := &workerManager{
|
|
||||||
cfg: cfg,
|
|
||||||
docker: docker,
|
|
||||||
http: &http.Client{
|
|
||||||
Timeout: cfg.RequestTimeout,
|
|
||||||
Transport: &http.Transport{
|
|
||||||
Proxy: nil,
|
|
||||||
DialContext: (&net.Dialer{
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).DialContext,
|
|
||||||
MaxIdleConnsPerHost: 8,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
workers: make(map[string]*workerHandle),
|
|
||||||
state: make(map[string]*workerAccountState),
|
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
logger: slog.Default().With("component", "lspool-worker-manager"),
|
|
||||||
}
|
|
||||||
if err := mgr.reconcileManagedContainers(ctx); err != nil {
|
|
||||||
cancel()
|
|
||||||
_ = docker.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go mgr.cleanupLoop()
|
|
||||||
return mgr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) Close() {
|
|
||||||
m.cancel()
|
|
||||||
|
|
||||||
m.mu.Lock()
|
|
||||||
workers := make([]*workerHandle, 0, len(m.workers))
|
|
||||||
for _, handle := range m.workers {
|
|
||||||
workers = append(workers, handle)
|
|
||||||
}
|
|
||||||
m.workers = make(map[string]*workerHandle)
|
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
for _, handle := range workers {
|
|
||||||
m.removeWorkerContainer(context.Background(), handle)
|
|
||||||
}
|
|
||||||
if m.docker != nil {
|
|
||||||
_ = m.docker.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) Stats() map[string]any {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
return map[string]any{
|
|
||||||
"accounts": len(m.state),
|
|
||||||
"total": len(m.workers),
|
|
||||||
"active": len(m.workers),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
state := m.ensureStateLocked(accountID)
|
|
||||||
state.HasToken = true
|
|
||||||
state.AccessToken = accessToken
|
|
||||||
state.RefreshToken = refreshToken
|
|
||||||
if expiresAt.IsZero() {
|
|
||||||
state.ExpiresAt = nil
|
|
||||||
} else {
|
|
||||||
ts := expiresAt.UTC()
|
|
||||||
state.ExpiresAt = &ts
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
state := m.ensureStateLocked(accountID)
|
|
||||||
state.HasModelCredits = true
|
|
||||||
state.UseAICredits = useAICredits
|
|
||||||
state.AvailableCredits = cloneInt32Ptr(availableCredits)
|
|
||||||
state.MinimumCreditAmount = cloneInt32Ptr(minimumCreditAmountForUsage)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) {
|
|
||||||
rawProxy := ""
|
|
||||||
if len(proxyURL) > 0 {
|
|
||||||
rawProxy = proxyURL[0]
|
|
||||||
}
|
|
||||||
normalizedProxy, parsedProxy, err := resolveWorkerProxy(rawProxy)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if parsedProxy == nil {
|
|
||||||
return nil, fmt.Errorf("ls worker requires a socks5/socks5h proxy for account %s", accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
replica := replicaSlotIndex(routingKey, parseLSReplicaCount())
|
|
||||||
proxyHash := proxyHash(normalizedProxy)
|
|
||||||
workerKey := buildWorkerKey(accountID, proxyHash)
|
|
||||||
|
|
||||||
m.mu.Lock()
|
|
||||||
state := cloneWorkerAccountState(m.state[accountID])
|
|
||||||
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
|
|
||||||
m.mu.Unlock()
|
|
||||||
return nil, fmt.Errorf("ls worker missing access token for account %s", accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
handle := m.workers[workerKey]
|
|
||||||
if handle == nil {
|
|
||||||
if len(m.workers) >= m.cfg.MaxActive {
|
|
||||||
m.mu.Unlock()
|
|
||||||
return nil, fmt.Errorf("ls worker limit reached (%d active)", m.cfg.MaxActive)
|
|
||||||
}
|
|
||||||
handle, err = m.createWorkerLocked(accountID, normalizedProxy, proxyHash, parsedProxy)
|
|
||||||
if err != nil {
|
|
||||||
m.mu.Unlock()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m.workers[workerKey] = handle
|
|
||||||
}
|
|
||||||
handle.LastUsed = time.Now()
|
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
if err := m.waitForWorkerHealthy(handle); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := m.syncWorkerState(handle, state); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := m.waitForWorkerReady(handle, routingKey); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
inst := &Instance{
|
|
||||||
AccountID: accountID,
|
|
||||||
Replica: replica,
|
|
||||||
Address: handle.Address,
|
|
||||||
client: m.http,
|
|
||||||
healthy: true,
|
|
||||||
lastUsed: time.Now(),
|
|
||||||
modelMapReady: 1,
|
|
||||||
remote: true,
|
|
||||||
workerToken: handle.AuthToken,
|
|
||||||
routingKey: routingKey,
|
|
||||||
}
|
|
||||||
return inst, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) cleanupLoop() {
|
|
||||||
ticker := time.NewTicker(time.Minute)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-m.ctx.Done():
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
m.collectIdleWorkers()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) collectIdleWorkers() {
|
|
||||||
now := time.Now()
|
|
||||||
var expired []*workerHandle
|
|
||||||
|
|
||||||
m.mu.Lock()
|
|
||||||
for key, handle := range m.workers {
|
|
||||||
if handle == nil {
|
|
||||||
delete(m.workers, key)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if now.Sub(handle.LastUsed) <= m.cfg.IdleTTL {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
expired = append(expired, handle)
|
|
||||||
delete(m.workers, key)
|
|
||||||
}
|
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
for _, handle := range expired {
|
|
||||||
m.removeWorkerContainer(context.Background(), handle)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) reconcileManagedContainers(ctx context.Context) error {
|
|
||||||
args := filters.NewArgs()
|
|
||||||
args.Add("label", fmt.Sprintf("%s=%s", lsWorkerManagedByLabel, lsWorkerManagedByValue))
|
|
||||||
|
|
||||||
containers, err := m.docker.ContainerList(ctx, container.ListOptions{
|
|
||||||
All: true,
|
|
||||||
Filters: args,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list managed ls workers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, summary := range containers {
|
|
||||||
handle := &workerHandle{
|
|
||||||
ContainerID: summary.ID,
|
|
||||||
Container: strings.TrimPrefix(firstContainerName(summary.Names), "/"),
|
|
||||||
}
|
|
||||||
if err := m.removeWorkerContainer(ctx, handle); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) createWorkerLocked(accountID, proxyURL, proxyHash string, parsedProxy *url.URL) (*workerHandle, error) {
|
|
||||||
containerName := fmt.Sprintf("sub2api-ls-%s-%s", accountID, proxyHash[:8])
|
|
||||||
authToken := generateUUID()
|
|
||||||
|
|
||||||
proxyHost := parsedProxy.Hostname()
|
|
||||||
proxyPort := parsedProxy.Port()
|
|
||||||
if proxyPort == "" {
|
|
||||||
proxyPort = "1080"
|
|
||||||
}
|
|
||||||
proxyUser := parsedProxy.User.Username()
|
|
||||||
proxyPass, _ := parsedProxy.User.Password()
|
|
||||||
|
|
||||||
labels := map[string]string{
|
|
||||||
lsWorkerManagedByLabel: lsWorkerManagedByValue,
|
|
||||||
lsWorkerAccountLabel: accountID,
|
|
||||||
lsWorkerProxyHashLabel: proxyHash,
|
|
||||||
lsWorkerImageTagLabel: m.cfg.Image,
|
|
||||||
}
|
|
||||||
|
|
||||||
env := []string{
|
|
||||||
"ANTIGRAVITY_APP_ROOT=/app/ls",
|
|
||||||
fmt.Sprintf("LSWORKER_ACCOUNT_ID=%s", accountID),
|
|
||||||
fmt.Sprintf("LSWORKER_AUTH_TOKEN=%s", authToken),
|
|
||||||
fmt.Sprintf("LSWORKER_LISTEN_ADDR=0.0.0.0:%d", lsWorkerControlPort),
|
|
||||||
fmt.Sprintf("LSWORKER_NETWORK_READY_FILE=%s", "/run/lsworker/network-ready"),
|
|
||||||
fmt.Sprintf("LSWORKER_PROXY_URL=%s", proxyURL),
|
|
||||||
fmt.Sprintf("LSWORKER_PROXY_HOST=%s", proxyHost),
|
|
||||||
fmt.Sprintf("LSWORKER_PROXY_PORT=%s", proxyPort),
|
|
||||||
fmt.Sprintf("LSWORKER_PROXY_USER=%s", proxyUser),
|
|
||||||
fmt.Sprintf("LSWORKER_PROXY_PASS=%s", proxyPass),
|
|
||||||
fmt.Sprintf("LSWORKER_CONTROL_PORT=%d", lsWorkerControlPort),
|
|
||||||
fmt.Sprintf("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=%d", parseLSReplicaCount()),
|
|
||||||
}
|
|
||||||
if tz := strings.TrimSpace(os.Getenv("TZ")); tz != "" {
|
|
||||||
env = append(env, "TZ="+tz)
|
|
||||||
}
|
|
||||||
|
|
||||||
createResp, err := m.docker.ContainerCreate(m.ctx, &container.Config{
|
|
||||||
Image: m.cfg.Image,
|
|
||||||
Labels: labels,
|
|
||||||
Env: env,
|
|
||||||
}, &container.HostConfig{
|
|
||||||
CapAdd: []string{"NET_ADMIN"},
|
|
||||||
}, &network.NetworkingConfig{
|
|
||||||
EndpointsConfig: map[string]*network.EndpointSettings{
|
|
||||||
m.cfg.Network: {},
|
|
||||||
},
|
|
||||||
}, nil, containerName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create ls worker container: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.docker.ContainerStart(m.ctx, createResp.ID, container.StartOptions{}); err != nil {
|
|
||||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
|
||||||
return nil, fmt.Errorf("start ls worker container: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
inspect, err := m.docker.ContainerInspect(m.ctx, createResp.ID)
|
|
||||||
if err != nil {
|
|
||||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
|
||||||
return nil, fmt.Errorf("inspect ls worker container: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
address, err := workerAddressFromInspect(inspect, m.cfg.Network)
|
|
||||||
if err != nil {
|
|
||||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.logger.Info("created ls worker",
|
|
||||||
"account", shortAccountID(accountID),
|
|
||||||
"container", containerName,
|
|
||||||
"address", address,
|
|
||||||
"proxy_hash", proxyHash[:8])
|
|
||||||
|
|
||||||
return &workerHandle{
|
|
||||||
Key: buildWorkerKey(accountID, proxyHash),
|
|
||||||
AccountID: accountID,
|
|
||||||
ProxyURL: proxyURL,
|
|
||||||
ProxyHash: proxyHash,
|
|
||||||
ContainerID: createResp.ID,
|
|
||||||
Container: containerName,
|
|
||||||
Address: address,
|
|
||||||
AuthToken: authToken,
|
|
||||||
LastUsed: time.Now(),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func workerAddressFromInspect(inspect container.InspectResponse, networkName string) (string, error) {
|
|
||||||
if inspect.NetworkSettings == nil {
|
|
||||||
return "", fmt.Errorf("ls worker inspect missing network settings")
|
|
||||||
}
|
|
||||||
if endpoint, ok := inspect.NetworkSettings.Networks[networkName]; ok && endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" {
|
|
||||||
return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil
|
|
||||||
}
|
|
||||||
for _, endpoint := range inspect.NetworkSettings.Networks {
|
|
||||||
if endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" {
|
|
||||||
return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("ls worker missing IP address on network %s", networkName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func firstContainerName(names []string) string {
|
|
||||||
if len(names) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return names[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) waitForWorkerHealthy(handle *workerHandle) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
for {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/healthz", nil), nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
|
||||||
resp, err := m.http.Do(req)
|
|
||||||
if err == nil {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return fmt.Errorf("worker %s failed health check: %w", handle.Container, ctx.Err())
|
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey string) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
values := url.Values{}
|
|
||||||
if strings.TrimSpace(routingKey) != "" {
|
|
||||||
values.Set("routing_key", routingKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
lastStatus int
|
|
||||||
lastBody string
|
|
||||||
)
|
|
||||||
|
|
||||||
for {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/readyz", values), nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
|
||||||
resp, err := m.http.Do(req)
|
|
||||||
if err == nil {
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
lastStatus = resp.StatusCode
|
|
||||||
lastBody = truncate(string(body), 200)
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if isWorkerModelMappingUnavailable(resp.StatusCode, lastBody) {
|
|
||||||
return fmt.Errorf("%w: worker %s %s", errLSModelMapDenied, handle.Container, strings.TrimSpace(lastBody))
|
|
||||||
}
|
|
||||||
if len(body) > 0 && shouldWarnWorkerNotReady(resp.StatusCode, lastBody) {
|
|
||||||
m.logger.Warn("ls worker not ready yet", "container", handle.Container, "status", resp.StatusCode, "body", truncate(string(body), 200))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
if lastStatus > 0 || lastBody != "" {
|
|
||||||
return fmt.Errorf("worker %s not ready for routing key %q (last_status=%d last_body=%q): %w", handle.Container, routingKey, lastStatus, lastBody, ctx.Err())
|
|
||||||
}
|
|
||||||
return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, ctx.Err())
|
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldWarnWorkerNotReady(status int, body string) bool {
|
|
||||||
if status == http.StatusServiceUnavailable {
|
|
||||||
normalized := strings.ToLower(strings.TrimSpace(body))
|
|
||||||
if strings.Contains(normalized, "model mapping not ready") {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isWorkerModelMappingUnavailable(status int, body string) bool {
|
|
||||||
if status != http.StatusServiceUnavailable {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
normalized := strings.ToLower(strings.TrimSpace(body))
|
|
||||||
return strings.Contains(normalized, errLSModelMapDenied.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error {
|
|
||||||
if state == nil {
|
|
||||||
return fmt.Errorf("ls worker state is nil")
|
|
||||||
}
|
|
||||||
body, err := json.Marshal(state)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshal worker state: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sum := fmt.Sprintf("%x", sha256.Sum256(body))
|
|
||||||
if handle.LastStateSHA == sum {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, workerURL(handle, "/account/state", nil), bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
|
||||||
|
|
||||||
resp, err := m.http.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("sync worker state: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("sync worker state HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200))
|
|
||||||
}
|
|
||||||
handle.LastStateSHA = sum
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func workerURL(handle *workerHandle, path string, values url.Values) string {
|
|
||||||
base := url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: handle.Address,
|
|
||||||
Path: path,
|
|
||||||
}
|
|
||||||
if values != nil {
|
|
||||||
base.RawQuery = values.Encode()
|
|
||||||
}
|
|
||||||
return base.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) removeWorkerContainer(ctx context.Context, handle *workerHandle) error {
|
|
||||||
if handle == nil || strings.TrimSpace(handle.ContainerID) == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
timeout := 5
|
|
||||||
_ = m.docker.ContainerStop(ctx, handle.ContainerID, container.StopOptions{Timeout: &timeout})
|
|
||||||
if err := m.docker.ContainerRemove(ctx, handle.ContainerID, container.RemoveOptions{Force: true}); err != nil {
|
|
||||||
return fmt.Errorf("remove ls worker container %s: %w", handle.ContainerID, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *workerManager) ensureStateLocked(accountID string) *workerAccountState {
|
|
||||||
state := m.state[accountID]
|
|
||||||
if state == nil {
|
|
||||||
state = &workerAccountState{}
|
|
||||||
m.state[accountID] = state
|
|
||||||
}
|
|
||||||
return state
|
|
||||||
}
|
|
||||||
|
|
||||||
func resolveWorkerProxy(proxyURL string) (string, *url.URL, error) {
|
|
||||||
resolved := resolveLSProxy(proxyURL)
|
|
||||||
normalized, parsed, err := proxyurl.Parse(resolved)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
if parsed == nil {
|
|
||||||
return "", nil, nil
|
|
||||||
}
|
|
||||||
switch strings.ToLower(parsed.Scheme) {
|
|
||||||
case "socks5", "socks5h":
|
|
||||||
return normalized, parsed, nil
|
|
||||||
default:
|
|
||||||
return "", nil, fmt.Errorf("ls worker only supports socks5/socks5h proxies, got %s", parsed.Scheme)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func proxyHash(proxyURL string) string {
|
|
||||||
if strings.TrimSpace(proxyURL) == "" {
|
|
||||||
return "direct"
|
|
||||||
}
|
|
||||||
sum := sha256.Sum256([]byte(strings.TrimSpace(proxyURL)))
|
|
||||||
return fmt.Sprintf("%x", sum[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildWorkerKey(accountID, proxyHash string) string {
|
|
||||||
return accountID + ":" + proxyHash
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloneInt32Ptr(v *int32) *int32 {
|
|
||||||
if v == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
cp := *v
|
|
||||||
return &cp
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloneWorkerAccountState(state *workerAccountState) *workerAccountState {
|
|
||||||
if state == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
cp := *state
|
|
||||||
cp.AvailableCredits = cloneInt32Ptr(state.AvailableCredits)
|
|
||||||
cp.MinimumCreditAmount = cloneInt32Ptr(state.MinimumCreditAmount)
|
|
||||||
if state.ExpiresAt != nil {
|
|
||||||
ts := *state.ExpiresAt
|
|
||||||
cp.ExpiresAt = &ts
|
|
||||||
}
|
|
||||||
return &cp
|
|
||||||
}
|
|
||||||
@ -1,335 +0,0 @@
|
|||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/docker/docker/api/types/container"
|
|
||||||
"github.com/docker/docker/api/types/filters"
|
|
||||||
"github.com/docker/docker/api/types/network"
|
|
||||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type fakeDockerClient struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
|
|
||||||
listResp []container.Summary
|
|
||||||
listCalls int
|
|
||||||
createCalls int
|
|
||||||
startCalls int
|
|
||||||
stopCalls int
|
|
||||||
removeCalls int
|
|
||||||
inspectCalls int
|
|
||||||
removedIDs []string
|
|
||||||
createdConfigs []*container.Config
|
|
||||||
inspectResp container.InspectResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error) {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
f.listCalls++
|
|
||||||
return append([]container.Summary(nil), f.listResp...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeDockerClient) ContainerCreate(ctx context.Context, cfg *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
f.createCalls++
|
|
||||||
f.createdConfigs = append(f.createdConfigs, cfg)
|
|
||||||
return container.CreateResponse{ID: "worker-created"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
f.startCalls++
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeDockerClient) ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
f.inspectCalls++
|
|
||||||
return f.inspectResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
f.stopCalls++
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
f.removeCalls++
|
|
||||||
f.removedIDs = append(f.removedIDs, containerID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeDockerClient) Close() error { return nil }
|
|
||||||
|
|
||||||
func TestResolveWorkerProxyRejectsHTTP(t *testing.T) {
|
|
||||||
_, _, err := resolveWorkerProxy("http://127.0.0.1:7890")
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "only supports socks5/socks5h")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyHashUsesNormalizedProxy(t *testing.T) {
|
|
||||||
normalized, _, err := resolveWorkerProxy("socks5://user:pass@127.0.0.1:1080")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "socks5h://user:pass@127.0.0.1:1080", normalized)
|
|
||||||
|
|
||||||
hash1 := proxyHash(normalized)
|
|
||||||
hash2 := proxyHash("socks5h://user:pass@127.0.0.1:1080")
|
|
||||||
require.Equal(t, hash1, hash2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWorkerManagerRequiresToken(t *testing.T) {
|
|
||||||
fakeDocker := &fakeDockerClient{}
|
|
||||||
manager, err := newWorkerManager(workerManagerConfig{
|
|
||||||
Image: "worker:latest",
|
|
||||||
Network: "sub2api-network",
|
|
||||||
DockerSocket: "unix:///var/run/docker.sock",
|
|
||||||
IdleTTL: time.Minute,
|
|
||||||
MaxActive: 2,
|
|
||||||
StartupTimeout: time.Second,
|
|
||||||
RequestTimeout: time.Second,
|
|
||||||
}, fakeDocker)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer manager.Close()
|
|
||||||
|
|
||||||
_, err = manager.GetOrCreate("9", "rk-1", "socks5h://user:pass@127.0.0.1:1080")
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "missing access token")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWorkerManagerReusesExistingHandleAndDedupesStateSync(t *testing.T) {
|
|
||||||
var mu sync.Mutex
|
|
||||||
var healthCalls int
|
|
||||||
var readyCalls int
|
|
||||||
var stateCalls int
|
|
||||||
var stateBodies [][]byte
|
|
||||||
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
switch r.URL.Path {
|
|
||||||
case "/healthz":
|
|
||||||
mu.Lock()
|
|
||||||
healthCalls++
|
|
||||||
mu.Unlock()
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write([]byte("ok"))
|
|
||||||
case "/readyz":
|
|
||||||
mu.Lock()
|
|
||||||
readyCalls++
|
|
||||||
mu.Unlock()
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write([]byte("ready"))
|
|
||||||
case "/account/state":
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
|
||||||
mu.Lock()
|
|
||||||
stateCalls++
|
|
||||||
stateBodies = append(stateBodies, body)
|
|
||||||
mu.Unlock()
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write([]byte("ok"))
|
|
||||||
default:
|
|
||||||
http.NotFound(w, r)
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
fakeDocker := &fakeDockerClient{}
|
|
||||||
manager, err := newWorkerManager(workerManagerConfig{
|
|
||||||
Image: "worker:latest",
|
|
||||||
Network: "sub2api-network",
|
|
||||||
DockerSocket: "unix:///var/run/docker.sock",
|
|
||||||
IdleTTL: time.Minute,
|
|
||||||
MaxActive: 4,
|
|
||||||
StartupTimeout: time.Second,
|
|
||||||
RequestTimeout: time.Second,
|
|
||||||
}, fakeDocker)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer manager.Close()
|
|
||||||
|
|
||||||
accountID := "9"
|
|
||||||
proxyURL := "socks5h://user:pass@127.0.0.1:1080"
|
|
||||||
hash := proxyHash(proxyURL)
|
|
||||||
key := buildWorkerKey(accountID, hash)
|
|
||||||
|
|
||||||
manager.SetAccountToken(accountID, "ya29.test", "refresh", time.Now().Add(time.Hour))
|
|
||||||
manager.mu.Lock()
|
|
||||||
manager.workers[key] = &workerHandle{
|
|
||||||
Key: key,
|
|
||||||
AccountID: accountID,
|
|
||||||
ProxyURL: proxyURL,
|
|
||||||
ProxyHash: hash,
|
|
||||||
ContainerID: "existing-worker",
|
|
||||||
Container: "sub2api-ls-9-test",
|
|
||||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
|
||||||
AuthToken: "worker-token",
|
|
||||||
LastUsed: time.Now(),
|
|
||||||
}
|
|
||||||
manager.mu.Unlock()
|
|
||||||
|
|
||||||
inst1, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.True(t, inst1.remote)
|
|
||||||
require.Equal(t, replicaSlotIndex("rk-1", parseLSReplicaCount()), inst1.Replica)
|
|
||||||
|
|
||||||
inst2, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.True(t, inst2.remote)
|
|
||||||
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
require.GreaterOrEqual(t, healthCalls, 2)
|
|
||||||
require.GreaterOrEqual(t, readyCalls, 2)
|
|
||||||
require.Equal(t, 1, stateCalls, "state sync should be skipped when the payload hash is unchanged")
|
|
||||||
require.Len(t, stateBodies, 1)
|
|
||||||
|
|
||||||
var synced workerAccountState
|
|
||||||
require.NoError(t, json.Unmarshal(stateBodies[0], &synced))
|
|
||||||
require.True(t, synced.HasToken)
|
|
||||||
require.Equal(t, "ya29.test", synced.AccessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWorkerManagerMaxActiveStopsNewWorkerCreation(t *testing.T) {
|
|
||||||
fakeDocker := &fakeDockerClient{}
|
|
||||||
manager, err := newWorkerManager(workerManagerConfig{
|
|
||||||
Image: "worker:latest",
|
|
||||||
Network: "sub2api-network",
|
|
||||||
DockerSocket: "unix:///var/run/docker.sock",
|
|
||||||
IdleTTL: time.Minute,
|
|
||||||
MaxActive: 1,
|
|
||||||
StartupTimeout: time.Second,
|
|
||||||
RequestTimeout: time.Second,
|
|
||||||
}, fakeDocker)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer manager.Close()
|
|
||||||
|
|
||||||
manager.SetAccountToken("9", "ya29.test", "refresh", time.Now().Add(time.Hour))
|
|
||||||
manager.mu.Lock()
|
|
||||||
manager.workers["existing"] = &workerHandle{ContainerID: "existing", Container: "existing", LastUsed: time.Now()}
|
|
||||||
manager.mu.Unlock()
|
|
||||||
|
|
||||||
_, err = manager.GetOrCreate("9", "rk-new", "socks5h://user:pass@127.0.0.1:1080")
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "limit reached")
|
|
||||||
require.Equal(t, 0, fakeDocker.createCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWorkerManagerReconcileRemovesManagedContainers(t *testing.T) {
|
|
||||||
fakeDocker := &fakeDockerClient{
|
|
||||||
listResp: []container.Summary{
|
|
||||||
{
|
|
||||||
ID: "old-worker-1",
|
|
||||||
Names: []string{"/sub2api-ls-9-deadbeef"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "old-worker-2",
|
|
||||||
Names: []string{"/sub2api-ls-10-beadfeed"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := newWorkerManager(workerManagerConfig{
|
|
||||||
Image: "worker:latest",
|
|
||||||
Network: "sub2api-network",
|
|
||||||
DockerSocket: "unix:///var/run/docker.sock",
|
|
||||||
IdleTTL: time.Minute,
|
|
||||||
MaxActive: 4,
|
|
||||||
StartupTimeout: time.Second,
|
|
||||||
RequestTimeout: time.Second,
|
|
||||||
}, fakeDocker)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer manager.Close()
|
|
||||||
|
|
||||||
require.Equal(t, 1, fakeDocker.listCalls)
|
|
||||||
require.ElementsMatch(t, []string{"old-worker-1", "old-worker-2"}, fakeDocker.removedIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFakeDockerClientImplementsFilterAwareList(t *testing.T) {
|
|
||||||
fakeDocker := &fakeDockerClient{}
|
|
||||||
_, err := fakeDocker.ContainerList(context.Background(), container.ListOptions{Filters: filters.NewArgs()})
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestShouldWarnWorkerNotReadySuppressesModelMappingPending(t *testing.T) {
|
|
||||||
require.False(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker model mapping not ready for replica 0"))
|
|
||||||
require.True(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker access token not configured"))
|
|
||||||
require.True(t, shouldWarnWorkerNotReady(http.StatusBadGateway, "upstream failed"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWorkerManagerWaitForWorkerReadyStopsOnModelMappingUnavailable(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
require.Equal(t, "/readyz", r.URL.Path)
|
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
|
||||||
_, _ = w.Write([]byte(`model mapping unavailable for replica 0: oauth2: "unauthorized_client" "Unauthorized"`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
manager, err := newWorkerManager(workerManagerConfig{
|
|
||||||
Image: "worker:latest",
|
|
||||||
Network: "sub2api-network",
|
|
||||||
DockerSocket: "unix:///var/run/docker.sock",
|
|
||||||
IdleTTL: time.Minute,
|
|
||||||
MaxActive: 1,
|
|
||||||
StartupTimeout: time.Second,
|
|
||||||
RequestTimeout: time.Second,
|
|
||||||
}, &fakeDockerClient{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer manager.Close()
|
|
||||||
|
|
||||||
handle := &workerHandle{
|
|
||||||
Container: "sub2api-ls-test",
|
|
||||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
|
||||||
AuthToken: "worker-token",
|
|
||||||
}
|
|
||||||
|
|
||||||
err = manager.waitForWorkerReady(handle, "")
|
|
||||||
require.Error(t, err)
|
|
||||||
require.ErrorIs(t, err, errLSModelMapDenied)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWorkerManagerWaitForWorkerReadyIncludesLastBodyOnTimeout(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
require.Equal(t, "/readyz", r.URL.Path)
|
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
|
||||||
_, _ = w.Write([]byte("worker model mapping not ready for replica 0\n"))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
manager, err := newWorkerManager(workerManagerConfig{
|
|
||||||
Image: "worker:latest",
|
|
||||||
Network: "sub2api-network",
|
|
||||||
DockerSocket: "unix:///var/run/docker.sock",
|
|
||||||
IdleTTL: time.Minute,
|
|
||||||
MaxActive: 1,
|
|
||||||
StartupTimeout: 100 * time.Millisecond,
|
|
||||||
RequestTimeout: time.Second,
|
|
||||||
}, &fakeDockerClient{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer manager.Close()
|
|
||||||
|
|
||||||
handle := &workerHandle{
|
|
||||||
Container: "sub2api-ls-test",
|
|
||||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
|
||||||
AuthToken: "worker-token",
|
|
||||||
}
|
|
||||||
|
|
||||||
err = manager.waitForWorkerReady(handle, "")
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), `last_status=503`)
|
|
||||||
require.Contains(t, err.Error(), `last_body="worker model mapping not ready for replica 0`)
|
|
||||||
}
|
|
||||||
@ -1,374 +0,0 @@
|
|||||||
package lspool
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type WorkerServerConfig struct {
|
|
||||||
AccountID string
|
|
||||||
AuthToken string
|
|
||||||
ListenAddr string
|
|
||||||
AppRoot string
|
|
||||||
NetworkReadyFile string
|
|
||||||
MaxIdleTime time.Duration
|
|
||||||
HealthInterval time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
type WorkerServer struct {
|
|
||||||
cfg WorkerServerConfig
|
|
||||||
pool *Pool
|
|
||||||
logger *slog.Logger
|
|
||||||
|
|
||||||
mu sync.RWMutex
|
|
||||||
state workerAccountState
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWorkerServer(cfg WorkerServerConfig) (*WorkerServer, error) {
|
|
||||||
if strings.TrimSpace(cfg.AccountID) == "" {
|
|
||||||
return nil, fmt.Errorf("worker account id is required")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(cfg.AuthToken) == "" {
|
|
||||||
return nil, fmt.Errorf("worker auth token is required")
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(cfg.ListenAddr) == "" {
|
|
||||||
cfg.ListenAddr = fmt.Sprintf("0.0.0.0:%d", lsWorkerControlPort)
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(cfg.AppRoot) == "" {
|
|
||||||
cfg.AppRoot = "/app/ls"
|
|
||||||
}
|
|
||||||
if cfg.MaxIdleTime <= 0 {
|
|
||||||
cfg.MaxIdleTime = 15 * time.Minute
|
|
||||||
}
|
|
||||||
if cfg.HealthInterval <= 0 {
|
|
||||||
cfg.HealthInterval = 30 * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
poolCfg := DefaultConfig()
|
|
||||||
poolCfg.AppRoot = cfg.AppRoot
|
|
||||||
poolCfg.MaxIdleTime = cfg.MaxIdleTime
|
|
||||||
poolCfg.HealthCheckInterval = cfg.HealthInterval
|
|
||||||
|
|
||||||
return &WorkerServer{
|
|
||||||
cfg: cfg,
|
|
||||||
pool: NewPool(poolCfg),
|
|
||||||
logger: slog.Default().With("component", "lsworker"),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWorkerServerFromEnv() (*WorkerServer, error) {
|
|
||||||
maxIdleTime := 15 * time.Minute
|
|
||||||
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_MAX_IDLE_TIME")); raw != "" {
|
|
||||||
if parsed, err := time.ParseDuration(raw); err == nil {
|
|
||||||
maxIdleTime = parsed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
healthInterval := 30 * time.Second
|
|
||||||
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_HEALTH_INTERVAL")); raw != "" {
|
|
||||||
if parsed, err := time.ParseDuration(raw); err == nil {
|
|
||||||
healthInterval = parsed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewWorkerServer(WorkerServerConfig{
|
|
||||||
AccountID: strings.TrimSpace(os.Getenv("LSWORKER_ACCOUNT_ID")),
|
|
||||||
AuthToken: strings.TrimSpace(os.Getenv("LSWORKER_AUTH_TOKEN")),
|
|
||||||
ListenAddr: strings.TrimSpace(os.Getenv("LSWORKER_LISTEN_ADDR")),
|
|
||||||
AppRoot: strings.TrimSpace(os.Getenv("ANTIGRAVITY_APP_ROOT")),
|
|
||||||
NetworkReadyFile: strings.TrimSpace(os.Getenv("LSWORKER_NETWORK_READY_FILE")),
|
|
||||||
MaxIdleTime: maxIdleTime,
|
|
||||||
HealthInterval: healthInterval,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WorkerServer) Close() {
|
|
||||||
if s.pool != nil {
|
|
||||||
s.pool.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WorkerServer) Handler() http.Handler {
|
|
||||||
mux := http.NewServeMux()
|
|
||||||
mux.HandleFunc("/healthz", s.handleHealthz)
|
|
||||||
mux.HandleFunc("/readyz", s.handleReadyz)
|
|
||||||
mux.HandleFunc("/account/state", s.handleAccountState)
|
|
||||||
mux.HandleFunc("/rpc/unary", s.handleRPCUnary)
|
|
||||||
mux.HandleFunc("/rpc/stream", s.handleRPCStream)
|
|
||||||
return mux
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WorkerServer) handleHealthz(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !s.authorize(w, r) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write([]byte("ok"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WorkerServer) handleReadyz(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !s.authorize(w, r) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
routingKey := strings.TrimSpace(r.URL.Query().Get("routing_key"))
|
|
||||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write([]byte(fmt.Sprintf("ready replica=%d", inst.Replica)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WorkerServer) handleAccountState(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !s.authorize(w, r) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if r.Method != http.MethodPost {
|
|
||||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer r.Body.Close()
|
|
||||||
var payload workerAccountState
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
|
||||||
http.Error(w, "invalid account state payload", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
s.state = *cloneWorkerAccountState(&payload)
|
|
||||||
s.mu.Unlock()
|
|
||||||
s.applyState()
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write([]byte("ok"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WorkerServer) handleRPCUnary(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !s.authorize(w, r) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "read request body failed", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(body) == 0 {
|
|
||||||
body = []byte("{}")
|
|
||||||
}
|
|
||||||
|
|
||||||
var respBody []byte
|
|
||||||
switch mode {
|
|
||||||
case "json":
|
|
||||||
var input any
|
|
||||||
if err := json.Unmarshal(body, &input); err != nil {
|
|
||||||
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
respBody, err = inst.CallUnaryJSON(r.Context(), service, method, input)
|
|
||||||
case "proto":
|
|
||||||
respBody, err = inst.CallRPC(r.Context(), service, method, body)
|
|
||||||
default:
|
|
||||||
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write(respBody)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WorkerServer) handleRPCStream(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !s.authorize(w, r) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "read request body failed", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp *http.Response
|
|
||||||
switch mode {
|
|
||||||
case "json":
|
|
||||||
var input any
|
|
||||||
if len(body) == 0 {
|
|
||||||
body = []byte("{}")
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(body, &input); err != nil {
|
|
||||||
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp, err = inst.StreamRPCJSON(r.Context(), service, method, input)
|
|
||||||
case "proto":
|
|
||||||
resp, err = inst.StreamRPC(r.Context(), service, method, body)
|
|
||||||
default:
|
|
||||||
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
for key, values := range resp.Header {
|
|
||||||
for _, value := range values {
|
|
||||||
w.Header().Add(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
_, _ = io.Copy(w, resp.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WorkerServer) authorize(w http.ResponseWriter, r *http.Request) bool {
|
|
||||||
if subtleHeaderEqual(r.Header.Get("X-Worker-Token"), s.cfg.AuthToken) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func subtleHeaderEqual(left, right string) bool {
|
|
||||||
if left == "" || right == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return left == right
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseRPCRequest(w http.ResponseWriter, r *http.Request) (service, method, mode, routingKey string, ok bool) {
|
|
||||||
if r.Method != http.MethodPost {
|
|
||||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
return "", "", "", "", false
|
|
||||||
}
|
|
||||||
query := r.URL.Query()
|
|
||||||
service = strings.TrimSpace(query.Get("service"))
|
|
||||||
method = strings.TrimSpace(query.Get("method"))
|
|
||||||
mode = strings.ToLower(strings.TrimSpace(query.Get("mode")))
|
|
||||||
routingKey = strings.TrimSpace(query.Get("routing_key"))
|
|
||||||
if service == "" || method == "" {
|
|
||||||
http.Error(w, "missing rpc target", http.StatusBadRequest)
|
|
||||||
return "", "", "", "", false
|
|
||||||
}
|
|
||||||
if mode == "" {
|
|
||||||
mode = "proto"
|
|
||||||
}
|
|
||||||
return service, method, mode, routingKey, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Instance, error) {
|
|
||||||
if path := strings.TrimSpace(s.cfg.NetworkReadyFile); path != "" {
|
|
||||||
if _, err := os.Stat(path); err != nil {
|
|
||||||
return nil, fmt.Errorf("worker network not ready: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.applyState()
|
|
||||||
s.mu.RLock()
|
|
||||||
state := cloneWorkerAccountState(&s.state)
|
|
||||||
s.mu.RUnlock()
|
|
||||||
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
|
|
||||||
return nil, fmt.Errorf("worker access token not configured")
|
|
||||||
}
|
|
||||||
|
|
||||||
inst, err := s.pool.GetOrCreate(s.cfg.AccountID, routingKey, "")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if inst.HasModelMappingUnavailable() {
|
|
||||||
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
|
|
||||||
}
|
|
||||||
if inst.HasModelMappingReady() {
|
|
||||||
return inst, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
modelCtx, cancel := context.WithTimeout(ctx, lsModelConfigTimeout)
|
|
||||||
defer cancel()
|
|
||||||
_ = modelCtx
|
|
||||||
if !RefreshModelMapping(inst) {
|
|
||||||
if inst.HasModelMappingUnavailable() {
|
|
||||||
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("worker model mapping not ready for replica %d", inst.Replica)
|
|
||||||
}
|
|
||||||
return inst, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WorkerServer) applyState() {
|
|
||||||
s.mu.RLock()
|
|
||||||
state := cloneWorkerAccountState(&s.state)
|
|
||||||
s.mu.RUnlock()
|
|
||||||
if state == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if state.HasToken {
|
|
||||||
expiresAt := time.Time{}
|
|
||||||
if state.ExpiresAt != nil {
|
|
||||||
expiresAt = state.ExpiresAt.UTC()
|
|
||||||
}
|
|
||||||
s.pool.SetAccountToken(s.cfg.AccountID, state.AccessToken, state.RefreshToken, expiresAt)
|
|
||||||
}
|
|
||||||
if state.HasModelCredits {
|
|
||||||
s.pool.SetAccountModelCredits(s.cfg.AccountID, state.UseAICredits, state.AvailableCredits, state.MinimumCreditAmount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func workerHTTPServer(listenAddr string, handler http.Handler) *http.Server {
|
|
||||||
return &http.Server{
|
|
||||||
Addr: listenAddr,
|
|
||||||
Handler: handler,
|
|
||||||
ReadHeaderTimeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func workerExitCode(err error) int {
|
|
||||||
if err == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseWorkerControlPort() int {
|
|
||||||
raw := strings.TrimSpace(os.Getenv("LSWORKER_CONTROL_PORT"))
|
|
||||||
if raw == "" {
|
|
||||||
return lsWorkerControlPort
|
|
||||||
}
|
|
||||||
port, err := strconv.Atoi(raw)
|
|
||||||
if err != nil || port < 1 {
|
|
||||||
return lsWorkerControlPort
|
|
||||||
}
|
|
||||||
return port
|
|
||||||
}
|
|
||||||
@ -1,11 +1,20 @@
|
|||||||
package routes
|
package routes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
anthropicEventLoggingURL = "https://api.anthropic.com/api/event_logging/batch"
|
||||||
|
eventLoggingForwardTimeout = 8 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)
|
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)
|
||||||
func RegisterCommonRoutes(r *gin.Engine) {
|
func RegisterCommonRoutes(r *gin.Engine) {
|
||||||
// 健康检查
|
// 健康检查
|
||||||
@ -13,8 +22,36 @@ func RegisterCommonRoutes(r *gin.Engine) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
})
|
})
|
||||||
|
|
||||||
// Claude Code 遥测日志(忽略,直接返回200)
|
// Claude Code 遥测日志:清理敏感字段后转发给 Anthropic。
|
||||||
|
// 删除 baseUrl/gateway 字段防止网关地址暴露(见 FINGERPRINT_SECURITY_REPORT.md §GAP-1/2)。
|
||||||
|
// 转发而非丢弃,避免"高流量零遥测"异常被检测。
|
||||||
r.POST("/api/event_logging/batch", func(c *gin.Context) {
|
r.POST("/api/event_logging/batch", func(c *gin.Context) {
|
||||||
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil || len(body) == 0 {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sanitized := sanitizeEventBatch(body)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), eventLoggingForwardTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, anthropicEventLoggingURL, bytes.NewReader(sanitized))
|
||||||
|
if err != nil {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
// 透传客户端的 Authorization header(OAuth Bearer token)
|
||||||
|
if auth := c.GetHeader("Authorization"); auth != "" {
|
||||||
|
req.Header.Set("Authorization", auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err == nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,9 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@ -12,7 +15,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Attribution block constants matching real Claude Code 2.1.88.
|
// Attribution block constants matching real Claude Code 2.1.89.
|
||||||
// Source: src/constants/system.ts + src/utils/fingerprint.ts
|
// Source: src/constants/system.ts + src/utils/fingerprint.ts
|
||||||
const (
|
const (
|
||||||
// fingerprintSalt must match the hardcoded salt in the real CLI.
|
// fingerprintSalt must match the hardcoded salt in the real CLI.
|
||||||
@ -81,11 +84,10 @@ func extractFirstUserMessageText(body []byte) string {
|
|||||||
// Source: extracted/src/constants/system.ts:73-95
|
// Source: extracted/src/constants/system.ts:73-95
|
||||||
func buildAttributionBlock(cliVersion, fingerprint string) string {
|
func buildAttributionBlock(cliVersion, fingerprint string) string {
|
||||||
version := cliVersion + "." + fingerprint
|
version := cliVersion + "." + fingerprint
|
||||||
// 注意:cch 字段由 Bun 的 NATIVE_CLIENT_ATTESTATION 编译时 feature 控制。
|
// 2.1.89 起 cch=00000 出现在所有安装模式(含 npm 版),不再只限于原生二进制。
|
||||||
// npm 安装版本(非原生二进制)此 feature 为 false,所以不包含 cch 字段。
|
// 原生二进制由 Bun 的 Zig 层在运行时将 00000 替换为真实 attestation hash;
|
||||||
// 只有原生二进制安装(Bun 打包)才会有 cch,且其值会被 Bun 的 Zig 层替换为真实 hash。
|
// 普通安装版保持 00000 占位符不变。
|
||||||
// 我们模拟 npm 安装版本的行为:不包含 cch。
|
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s; cc_entrypoint=cli; cch=00000;", version)
|
||||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s; cc_entrypoint=cli;", version)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// injectAttributionBlock prepends the x-anthropic-billing-header attribution block
|
// injectAttributionBlock prepends the x-anthropic-billing-header attribution block
|
||||||
@ -163,20 +165,89 @@ func injectAttributionBlock(body []byte, cliVersion string) []byte {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateSessionIDForAccount generates a deterministic per-account session UUID
|
// cliSessionEntry holds a cached session UUID with an expiration time.
|
||||||
// that remains stable within a process-like timeframe.
|
type cliSessionEntry struct {
|
||||||
// Uses instanceSalt + accountID to ensure uniqueness across sub2api instances.
|
id string
|
||||||
func generateSessionIDForAccount(instanceSalt string, accountID int64) string {
|
expiresAt time.Time
|
||||||
// Use a per-account stable UUID (like real CLI's per-process UUID).
|
}
|
||||||
// We use accountID as the base — each account gets a different "session".
|
|
||||||
seed := fmt.Sprintf("session:%s:%d", instanceSalt, accountID)
|
// cliSessionCache stores per-account session UUIDs that rotate on a TTL.
|
||||||
hash := sha256.Sum256([]byte(seed))
|
// Real CLI creates a new random UUID per process invocation; we approximate
|
||||||
sessionUUID, err := uuid.FromBytes(hash[:16])
|
// this by rotating every 30-60 minutes (jittered per account).
|
||||||
if err != nil {
|
var (
|
||||||
return uuid.New().String()
|
cliSessionCache = make(map[int64]cliSessionEntry)
|
||||||
}
|
cliSessionCacheMu sync.Mutex
|
||||||
// Set UUID v4 variant
|
)
|
||||||
sessionUUID[6] = (sessionUUID[6] & 0x0f) | 0x40
|
|
||||||
sessionUUID[8] = (sessionUUID[8] & 0x3f) | 0x80
|
// sessionTTLBase is the base TTL for session ID rotation.
|
||||||
return sessionUUID.String()
|
const sessionTTLBase = 30 * time.Minute
|
||||||
|
|
||||||
|
// generateSessionIDForAccount returns a per-account session UUID that rotates
|
||||||
|
// periodically. Each account gets a random TTL jitter (0-30 min on top of
|
||||||
|
// the 30 min base) so accounts don't all rotate simultaneously.
|
||||||
|
func generateSessionIDForAccount(instanceSalt string, accountID int64) string {
|
||||||
|
cliSessionCacheMu.Lock()
|
||||||
|
defer cliSessionCacheMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
if entry, ok := cliSessionCache[accountID]; ok && now.Before(entry.expiresAt) {
|
||||||
|
return entry.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute per-account jitter from a hash so the same account always gets
|
||||||
|
// the same jitter within a process (avoids re-rolling on every rotation).
|
||||||
|
jitterSeed := fmt.Sprintf("jitter:%s:%d", instanceSalt, accountID)
|
||||||
|
h := sha256.Sum256([]byte(jitterSeed))
|
||||||
|
jitterMinutes := int(h[0]) % 31 // 0-30 minutes
|
||||||
|
ttl := sessionTTLBase + time.Duration(jitterMinutes)*time.Minute
|
||||||
|
|
||||||
|
newID := uuid.New().String()
|
||||||
|
cliSessionCache[accountID] = cliSessionEntry{
|
||||||
|
id: newID,
|
||||||
|
expiresAt: now.Add(ttl),
|
||||||
|
}
|
||||||
|
return newID
|
||||||
|
}
|
||||||
|
|
||||||
|
// reUserHome matches /Users/<username>/ or /home/<username>/ path segments.
|
||||||
|
// Captures the prefix (/Users/ or /home/) so we can preserve it while replacing the username.
|
||||||
|
var reUserHome = regexp.MustCompile(`(/(Users|home)/)[^/\s"']+/`)
|
||||||
|
|
||||||
|
// reEnvLine matches lines of the form "Key: value" for the environment block
|
||||||
|
// fields injected by Claude Code's CLAUDE.md / sysprompt machinery.
|
||||||
|
var reEnvLine = regexp.MustCompile(`(?m)^(Platform|Shell|OS Version|Working directory):.*$`)
|
||||||
|
|
||||||
|
// canonicalEnvValues maps environment block keys to their canonical replacements.
|
||||||
|
// Values mirror cc-gateway's prompt_env config and represent a stock macOS dev machine.
|
||||||
|
var canonicalEnvValues = map[string]string{
|
||||||
|
"Platform": "Platform: darwin",
|
||||||
|
"Shell": "Shell: zsh",
|
||||||
|
"OS Version": "OS Version: Darwin 24.4.0",
|
||||||
|
"Working directory": "Working directory: /Users/user/project",
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeSystemPromptEnv rewrites environment-specific fields in a system
|
||||||
|
// prompt text block to canonical values, preventing real machine fingerprinting.
|
||||||
|
//
|
||||||
|
// Handles two classes of leakage (matching cc-gateway rewriter.ts:rewritePromptText):
|
||||||
|
// 1. "Platform: Windows / Linux / Darwin 25.x" → canonical darwin/zsh/Darwin 24.4.0
|
||||||
|
// 2. "/Users/alice/" or "/home/bob/" → "/Users/user/"
|
||||||
|
//
|
||||||
|
// Only called on system prompt text blocks, never on user message content.
|
||||||
|
func NormalizeSystemPromptEnv(text string) string {
|
||||||
|
// Replace env-info lines with canonical values
|
||||||
|
text = reEnvLine.ReplaceAllStringFunc(text, func(line string) string {
|
||||||
|
for key, canonical := range canonicalEnvValues {
|
||||||
|
if len(line) >= len(key) && line[:len(key)] == key {
|
||||||
|
return canonical
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return line
|
||||||
|
})
|
||||||
|
|
||||||
|
// Redact real usernames in home directory paths
|
||||||
|
// e.g. /Users/alice/project -> /Users/user/project
|
||||||
|
text = reUserHome.ReplaceAllString(text, "${1}user/")
|
||||||
|
|
||||||
|
return text
|
||||||
}
|
}
|
||||||
|
|||||||
@ -895,6 +895,9 @@ func sanitizeSystemText(text string) string {
|
|||||||
"You are OpenCode, the best coding agent on the planet.",
|
"You are OpenCode, the best coding agent on the planet.",
|
||||||
strings.TrimSpace(claudeCodeSystemPrompt),
|
strings.TrimSpace(claudeCodeSystemPrompt),
|
||||||
)
|
)
|
||||||
|
// Normalize environment block fields (Platform/Shell/OS Version/Working directory)
|
||||||
|
// to canonical values so different client machines don't create fingerprint divergence.
|
||||||
|
text = NormalizeSystemPromptEnv(text)
|
||||||
return text
|
return text
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5773,7 +5776,7 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
|
|||||||
return claude.HaikuBetaHeader
|
return claude.HaikuBetaHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
return claude.DefaultBetaHeader
|
return claude.GetOAuthBetaHeader(modelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestNeedsBetaFeatures(body []byte) bool {
|
func requestNeedsBetaFeatures(body []byte) bool {
|
||||||
@ -5790,10 +5793,7 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
|||||||
|
|
||||||
func defaultAPIKeyBetaHeader(body []byte) string {
|
func defaultAPIKeyBetaHeader(body []byte) string {
|
||||||
modelID := gjson.GetBytes(body, "model").String()
|
modelID := gjson.GetBytes(body, "model").String()
|
||||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
return claude.GetAPIKeyBetaHeader(modelID)
|
||||||
return claude.APIKeyHaikuBetaHeader
|
|
||||||
}
|
|
||||||
return claude.APIKeyBetaHeader
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyClaudeOAuthHeaderDefaults(req *http.Request) {
|
func applyClaudeOAuthHeaderDefaults(req *http.Request) {
|
||||||
|
|||||||
@ -26,7 +26,7 @@ var (
|
|||||||
|
|
||||||
// 默认指纹值(当客户端未提供时使用)
|
// 默认指纹值(当客户端未提供时使用)
|
||||||
var defaultFingerprint = Fingerprint{
|
var defaultFingerprint = Fingerprint{
|
||||||
UserAgent: "claude-cli/2.1.88 (external, cli)",
|
UserAgent: "claude-cli/2.1.89 (external, cli)",
|
||||||
StainlessLang: "js",
|
StainlessLang: "js",
|
||||||
StainlessPackageVersion: "0.74.0",
|
StainlessPackageVersion: "0.74.0",
|
||||||
StainlessOS: "MacOS",
|
StainlessOS: "MacOS",
|
||||||
|
|||||||
@ -1,225 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultLSPoolBootstrapConcurrency = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
type lsBootstrapAccountReader interface {
|
|
||||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LSPoolBootstrapService pre-creates LS workers for eligible Antigravity accounts on startup.
|
|
||||||
type LSPoolBootstrapService struct {
|
|
||||||
accountReader lsBootstrapAccountReader
|
|
||||||
backend lspool.Backend
|
|
||||||
cfg *config.Config
|
|
||||||
logger *slog.Logger
|
|
||||||
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
once sync.Once
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLSPoolBootstrapService(accountReader lsBootstrapAccountReader, backend lspool.Backend, cfg *config.Config) *LSPoolBootstrapService {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
return &LSPoolBootstrapService{
|
|
||||||
accountReader: accountReader,
|
|
||||||
backend: backend,
|
|
||||||
cfg: cfg,
|
|
||||||
logger: slog.Default().With("component", "service.lspool_bootstrap"),
|
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProvideLSPoolBootstrapService creates and starts the LS pool bootstrap worker.
|
|
||||||
func ProvideLSPoolBootstrapService(accountRepo AccountRepository, cfg *config.Config) *LSPoolBootstrapService {
|
|
||||||
svc := NewLSPoolBootstrapService(accountRepo, lspool.GlobalPool(cfg), cfg)
|
|
||||||
svc.Start()
|
|
||||||
return svc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LSPoolBootstrapService) Start() {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.once.Do(func() {
|
|
||||||
if s.backend == nil {
|
|
||||||
if lspool.IsLSModeEnabled() {
|
|
||||||
s.logger.Warn("startup bootstrap skipped: ls backend unavailable")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer s.wg.Done()
|
|
||||||
s.bootstrap(s.ctx)
|
|
||||||
}()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LSPoolBootstrapService) Stop() {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.cancel()
|
|
||||||
s.wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LSPoolBootstrapService) bootstrap(ctx context.Context) {
|
|
||||||
if s.backend == nil || s.accountReader == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
accounts, err := s.accountReader.ListByPlatform(ctx, PlatformAntigravity)
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Warn("load antigravity accounts for ls bootstrap failed", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
candidates := make([]Account, 0, len(accounts))
|
|
||||||
for i := range accounts {
|
|
||||||
if shouldBootstrapLSPoolAccount(&accounts[i], now) {
|
|
||||||
candidates = append(candidates, accounts[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(candidates) == 0 {
|
|
||||||
s.logger.Info("startup bootstrap skipped: no eligible antigravity accounts")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.logger.Info("starting ls worker bootstrap",
|
|
||||||
"accounts_total", len(accounts),
|
|
||||||
"accounts_eligible", len(candidates),
|
|
||||||
"concurrency", s.bootstrapConcurrency())
|
|
||||||
|
|
||||||
var (
|
|
||||||
mu sync.Mutex
|
|
||||||
started int
|
|
||||||
failed int
|
|
||||||
)
|
|
||||||
sem := make(chan struct{}, s.bootstrapConcurrency())
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
loop:
|
|
||||||
for i := range candidates {
|
|
||||||
account := candidates[i]
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
break loop
|
|
||||||
case sem <- struct{}{}:
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func(account Account) {
|
|
||||||
defer wg.Done()
|
|
||||||
defer func() { <-sem }()
|
|
||||||
|
|
||||||
if err := s.bootstrapAccount(&account); err != nil {
|
|
||||||
mu.Lock()
|
|
||||||
failed++
|
|
||||||
mu.Unlock()
|
|
||||||
s.logger.Warn("bootstrap ls worker failed", "account_id", account.ID, "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
mu.Lock()
|
|
||||||
started++
|
|
||||||
mu.Unlock()
|
|
||||||
s.logger.Info("bootstrap ls worker ready", "account_id", account.ID)
|
|
||||||
}(account)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
s.logger.Info("ls worker bootstrap completed",
|
|
||||||
"accounts_total", len(accounts),
|
|
||||||
"accounts_eligible", len(candidates),
|
|
||||||
"workers_ready", started,
|
|
||||||
"workers_failed", failed,
|
|
||||||
"canceled", ctx.Err() != nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LSPoolBootstrapService) bootstrapAccount(account *Account) error {
|
|
||||||
if s.backend == nil {
|
|
||||||
return fmt.Errorf("ls backend unavailable")
|
|
||||||
}
|
|
||||||
if account == nil {
|
|
||||||
return fmt.Errorf("account is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
accountKey := strconv.FormatInt(account.ID, 10)
|
|
||||||
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
|
||||||
if accessToken == "" {
|
|
||||||
return fmt.Errorf("missing access token")
|
|
||||||
}
|
|
||||||
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
|
|
||||||
|
|
||||||
expiresAt := time.Time{}
|
|
||||||
if ts := account.GetCredentialAsTime("expires_at"); ts != nil {
|
|
||||||
expiresAt = ts.UTC()
|
|
||||||
}
|
|
||||||
|
|
||||||
s.backend.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt)
|
|
||||||
availableCredits, minimumCreditAmount := resolveLSPoolModelCreditsState(account)
|
|
||||||
s.backend.SetAccountModelCredits(accountKey, account.IsOveragesEnabled(), availableCredits, minimumCreditAmount)
|
|
||||||
|
|
||||||
proxyURL := ""
|
|
||||||
if account.Proxy != nil {
|
|
||||||
proxyURL = account.Proxy.URL()
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := s.backend.GetOrCreate(accountKey, "", proxyURL); err != nil {
|
|
||||||
return fmt.Errorf("get or create ls worker: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *LSPoolBootstrapService) bootstrapConcurrency() int {
|
|
||||||
parallelism := defaultLSPoolBootstrapConcurrency
|
|
||||||
if s.cfg != nil && s.cfg.Gateway.AntigravityLSWorker.MaxActive > 0 && s.cfg.Gateway.AntigravityLSWorker.MaxActive < parallelism {
|
|
||||||
parallelism = s.cfg.Gateway.AntigravityLSWorker.MaxActive
|
|
||||||
}
|
|
||||||
if parallelism < 1 {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return parallelism
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldBootstrapLSPoolAccount(account *Account, now time.Time) bool {
|
|
||||||
if account == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if account.Platform != PlatformAntigravity {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if account.Type != AccountTypeOAuth {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if account.Status != StatusActive || !account.Schedulable {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if account.AutoPauseOnExpired && account.ExpiresAt != nil && !now.Before(*account.ExpiresAt) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(account.GetCredential("access_token")) == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(account.GetCredential("project_id")) != ""
|
|
||||||
}
|
|
||||||
@ -1,262 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type fakeLSBootstrapAccountReader struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
accounts []Account
|
|
||||||
err error
|
|
||||||
platforms []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeLSBootstrapAccountReader) ListByPlatform(_ context.Context, platform string) ([]Account, error) {
|
|
||||||
f.mu.Lock()
|
|
||||||
f.platforms = append(f.platforms, platform)
|
|
||||||
accounts := append([]Account(nil), f.accounts...)
|
|
||||||
err := f.err
|
|
||||||
f.mu.Unlock()
|
|
||||||
return accounts, err
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeLSPoolBackend struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
tokenCalls map[string]fakeLSPoolTokenCall
|
|
||||||
creditCalls map[string]fakeLSPoolCreditCall
|
|
||||||
getCalls []fakeLSPoolGetCall
|
|
||||||
getErrs map[string]error
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeLSPoolTokenCall struct {
|
|
||||||
AccessToken string
|
|
||||||
RefreshToken string
|
|
||||||
ExpiresAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeLSPoolCreditCall struct {
|
|
||||||
UseAICredits bool
|
|
||||||
AvailableCredits *int32
|
|
||||||
MinimumCreditAmount *int32
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeLSPoolGetCall struct {
|
|
||||||
AccountID string
|
|
||||||
RoutingKey string
|
|
||||||
ProxyURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFakeLSPoolBackend() *fakeLSPoolBackend {
|
|
||||||
return &fakeLSPoolBackend{
|
|
||||||
tokenCalls: make(map[string]fakeLSPoolTokenCall),
|
|
||||||
creditCalls: make(map[string]fakeLSPoolCreditCall),
|
|
||||||
getErrs: make(map[string]error),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeLSPoolBackend) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*lspool.Instance, error) {
|
|
||||||
rawProxy := ""
|
|
||||||
if len(proxyURL) > 0 {
|
|
||||||
rawProxy = proxyURL[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
f.getCalls = append(f.getCalls, fakeLSPoolGetCall{
|
|
||||||
AccountID: accountID,
|
|
||||||
RoutingKey: routingKey,
|
|
||||||
ProxyURL: rawProxy,
|
|
||||||
})
|
|
||||||
if err := f.getErrs[accountID]; err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &lspool.Instance{AccountID: accountID}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeLSPoolBackend) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
f.tokenCalls[accountID] = fakeLSPoolTokenCall{
|
|
||||||
AccessToken: accessToken,
|
|
||||||
RefreshToken: refreshToken,
|
|
||||||
ExpiresAt: expiresAt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeLSPoolBackend) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
f.creditCalls[accountID] = fakeLSPoolCreditCall{
|
|
||||||
UseAICredits: useAICredits,
|
|
||||||
AvailableCredits: copyInt32Ptr(availableCredits),
|
|
||||||
MinimumCreditAmount: copyInt32Ptr(minimumCreditAmountForUsage),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeLSPoolBackend) Stats() map[string]any { return nil }
|
|
||||||
|
|
||||||
func (f *fakeLSPoolBackend) Close() {}
|
|
||||||
|
|
||||||
func copyInt32Ptr(v *int32) *int32 {
|
|
||||||
if v == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
cp := *v
|
|
||||||
return &cp
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLSPoolBootstrapServiceBootstrapEligibleAccounts(t *testing.T) {
|
|
||||||
expiresAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
|
|
||||||
expiredAt := time.Now().Add(-2 * time.Hour)
|
|
||||||
reader := &fakeLSBootstrapAccountReader{
|
|
||||||
accounts: []Account{
|
|
||||||
{
|
|
||||||
ID: 101,
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"access_token": "token-101",
|
|
||||||
"refresh_token": "refresh-101",
|
|
||||||
"expires_at": expiresAt.Format(time.RFC3339),
|
|
||||||
"project_id": "proj-101",
|
|
||||||
},
|
|
||||||
Extra: map[string]any{
|
|
||||||
"allow_overages": true,
|
|
||||||
"ai_credits": []any{
|
|
||||||
map[string]any{
|
|
||||||
"credit_type": "GOOGLE_ONE_AI",
|
|
||||||
"amount": 120,
|
|
||||||
"minimum_balance": 55,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Proxy: &Proxy{
|
|
||||||
Protocol: "socks5h",
|
|
||||||
Host: "127.0.0.1",
|
|
||||||
Port: 1080,
|
|
||||||
Username: "alice",
|
|
||||||
Password: "secret",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: 102,
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: false,
|
|
||||||
Credentials: map[string]any{"access_token": "token-102", "project_id": "proj-102"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: 103,
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
Credentials: map[string]any{"access_token": "token-103"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: 104,
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
AutoPauseOnExpired: true,
|
|
||||||
ExpiresAt: &expiredAt,
|
|
||||||
Credentials: map[string]any{"access_token": "token-104", "project_id": "proj-104"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: 106,
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Type: AccountTypeUpstream,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
Credentials: map[string]any{"access_token": "token-106", "project_id": "proj-106"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: 105,
|
|
||||||
Platform: PlatformOpenAI,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
Credentials: map[string]any{"access_token": "token-105"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
backend := newFakeLSPoolBackend()
|
|
||||||
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{
|
|
||||||
Gateway: config.GatewayConfig{
|
|
||||||
AntigravityLSWorker: config.GatewayAntigravityLSWorkerConfig{MaxActive: 3},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
svc.bootstrap(context.Background())
|
|
||||||
|
|
||||||
require.Equal(t, []string{PlatformAntigravity}, reader.platforms)
|
|
||||||
|
|
||||||
require.Len(t, backend.getCalls, 1)
|
|
||||||
require.Equal(t, fakeLSPoolGetCall{
|
|
||||||
AccountID: "101",
|
|
||||||
RoutingKey: "",
|
|
||||||
ProxyURL: "socks5h://alice:secret@127.0.0.1:1080",
|
|
||||||
}, backend.getCalls[0])
|
|
||||||
|
|
||||||
tokenCall, ok := backend.tokenCalls["101"]
|
|
||||||
require.True(t, ok)
|
|
||||||
require.Equal(t, "token-101", tokenCall.AccessToken)
|
|
||||||
require.Equal(t, "refresh-101", tokenCall.RefreshToken)
|
|
||||||
require.Equal(t, expiresAt, tokenCall.ExpiresAt)
|
|
||||||
|
|
||||||
creditCall, ok := backend.creditCalls["101"]
|
|
||||||
require.True(t, ok)
|
|
||||||
require.True(t, creditCall.UseAICredits)
|
|
||||||
require.NotNil(t, creditCall.AvailableCredits)
|
|
||||||
require.Equal(t, int32(120), *creditCall.AvailableCredits)
|
|
||||||
require.NotNil(t, creditCall.MinimumCreditAmount)
|
|
||||||
require.Equal(t, int32(55), *creditCall.MinimumCreditAmount)
|
|
||||||
|
|
||||||
require.NotContains(t, backend.tokenCalls, "102")
|
|
||||||
require.NotContains(t, backend.tokenCalls, "103")
|
|
||||||
require.NotContains(t, backend.tokenCalls, "104")
|
|
||||||
require.NotContains(t, backend.tokenCalls, "106")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLSPoolBootstrapServiceBootstrapContinuesOnWorkerFailure(t *testing.T) {
|
|
||||||
reader := &fakeLSBootstrapAccountReader{
|
|
||||||
accounts: []Account{
|
|
||||||
{
|
|
||||||
ID: 201,
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
Credentials: map[string]any{"access_token": "token-201", "project_id": "proj-201"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: 202,
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
Credentials: map[string]any{"access_token": "token-202", "project_id": "proj-202"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
backend := newFakeLSPoolBackend()
|
|
||||||
backend.getErrs["201"] = errors.New("create failed")
|
|
||||||
|
|
||||||
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{})
|
|
||||||
svc.bootstrap(context.Background())
|
|
||||||
|
|
||||||
require.Len(t, backend.getCalls, 2)
|
|
||||||
require.Contains(t, backend.tokenCalls, "201")
|
|
||||||
require.Contains(t, backend.tokenCalls, "202")
|
|
||||||
}
|
|
||||||
@ -1,21 +0,0 @@
|
|||||||
-----BEGIN CERTIFICATE-----
|
|
||||||
MIIDXTCCAkWgAwIBAgIUVoRddTlTFh3+shRe6g4kSLo2n0MwDQYJKoZIhvcNAQEL
|
|
||||||
BQAwSTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1FTkFCTEVTIEhUVFAy
|
|
||||||
MRswGQYDVQQLDBJidW5kbGVkIG9uIHB1cnBvc2UwHhcNMjUwOTA0MjA1NTA0WhcN
|
|
||||||
MjYwOTA0MjA1NTA0WjBJMRIwEAYDVQQDDAlsb2NhbGhvc3QxFjAUBgNVBAoMDUVO
|
|
||||||
QUJMRVMgSFRUUDIxGzAZBgNVBAsMEmJ1bmRsZWQgb24gcHVycG9zZTCCASIwDQYJ
|
|
||||||
KoZIhvcNAQEBBQADggEPADCCAQoCggEBAJVpU6IyIMgwB6CJHkOeEAgYtzvyH6fM
|
|
||||||
lkZSbemTrD9RCWZ4Fati1/6vbbMyWsM2XNJQMhJo0JTEoLDddN1iV/xGJCO/3dgw
|
|
||||||
4+wLqqEeck4R1pHygCkb40TycmyygSWsidkEUH0xp51nCapIdPr/WL6O+Gbpl6DA
|
|
||||||
onerUmWIO39VG2SpV7x3iXZOSbIGMsOiNZBmGwBZcL8ZejBIDjwvNjnX/d2tejH5
|
|
||||||
/Mo4KVEXl5jsqaNbDIkhSs5BXtCMhoi1dqt75M8FyuNZd50AGFSa9Lj6pHTpwepD
|
|
||||||
k2x4h+czPcvscF7TQG31TK1VYFPUThDim+by0+LQKkpy/UGVWnbC4dsCAwEAAaM9
|
|
||||||
MDswGgYDVR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMB0GA1UdDgQWBBSonSKmHCVt
|
|
||||||
yBoVH1xEb3vtCng80DANBgkqhkiG9w0BAQsFAAOCAQEAinBO/uYe8ExHeiskt2P/
|
|
||||||
Oxkd5sHSY9deLVuyX/TFnUEfktMfYKM2Juy+MfH4vfrcEhYkYJJcm25UGrtiT0Jh
|
|
||||||
bUooDkR53549Xzg/70HU/ls1eNIe0zYqmS12H5W4Q1LAWTVpePscB4dgOrps6xIk
|
|
||||||
Q4nlF7dst93E3swAe81rgCEd7VZEZy5VQcE9K+CIZXaAUJwUAsAtJbrP+5JMe9pt
|
|
||||||
q52Zq5ZVkBS+4xeaMrasN0iTgsS4Lxo2a0GFDIJ84V66oeX7a5SXfSNn7rMVIDai
|
|
||||||
KNZ2Cf2xNXUwq25Z6tjpQCqwYn3SE8b/Yi6fFZmy5D8kmY7dMh8ghVOc7rD+Vsk6
|
|
||||||
/Q==
|
|
||||||
-----END CERTIFICATE-----
|
|
||||||
Binary file not shown.
Binary file not shown.
@ -1,70 +0,0 @@
|
|||||||
#!/bin/sh
|
|
||||||
set -eu
|
|
||||||
|
|
||||||
PROXY_HOST="${LSWORKER_PROXY_HOST:-}"
|
|
||||||
PROXY_PORT="${LSWORKER_PROXY_PORT:-1080}"
|
|
||||||
PROXY_USER="${LSWORKER_PROXY_USER:-}"
|
|
||||||
PROXY_PASS="${LSWORKER_PROXY_PASS:-}"
|
|
||||||
CONTROL_PORT="${LSWORKER_CONTROL_PORT:-18081}"
|
|
||||||
REDSOCKS_PORT="${LSWORKER_REDSOCKS_PORT:-12345}"
|
|
||||||
NETWORK_READY_FILE="${LSWORKER_NETWORK_READY_FILE:-/run/lsworker/network-ready}"
|
|
||||||
|
|
||||||
mkdir -p "$(dirname "${NETWORK_READY_FILE}")"
|
|
||||||
|
|
||||||
if [ -z "${PROXY_HOST}" ]; then
|
|
||||||
echo "LSWORKER_PROXY_HOST is required" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
PROXY_IP="$(getent ahostsv4 "${PROXY_HOST}" | awk 'NR==1 {print $1}')"
|
|
||||||
if [ -z "${PROXY_IP}" ]; then
|
|
||||||
echo "failed to resolve proxy host: ${PROXY_HOST}" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
cat >/tmp/redsocks.conf <<EOF
|
|
||||||
base {
|
|
||||||
log_debug = off;
|
|
||||||
log_info = on;
|
|
||||||
daemon = off;
|
|
||||||
redirector = iptables;
|
|
||||||
}
|
|
||||||
|
|
||||||
redsocks {
|
|
||||||
local_ip = 0.0.0.0;
|
|
||||||
local_port = ${REDSOCKS_PORT};
|
|
||||||
ip = ${PROXY_IP};
|
|
||||||
port = ${PROXY_PORT};
|
|
||||||
type = socks5;
|
|
||||||
EOF
|
|
||||||
|
|
||||||
if [ -n "${PROXY_USER}" ]; then
|
|
||||||
printf ' login = "%s";\n' "${PROXY_USER}" >>/tmp/redsocks.conf
|
|
||||||
fi
|
|
||||||
if [ -n "${PROXY_PASS}" ]; then
|
|
||||||
printf ' password = "%s";\n' "${PROXY_PASS}" >>/tmp/redsocks.conf
|
|
||||||
fi
|
|
||||||
|
|
||||||
cat >>/tmp/redsocks.conf <<EOF
|
|
||||||
}
|
|
||||||
EOF
|
|
||||||
|
|
||||||
redsocks -c /tmp/redsocks.conf >/tmp/redsocks.log 2>&1 &
|
|
||||||
REDSOCKS_PID="$!"
|
|
||||||
trap 'kill "${REDSOCKS_PID}" >/dev/null 2>&1 || true' EXIT
|
|
||||||
|
|
||||||
sleep 1
|
|
||||||
|
|
||||||
iptables -t nat -N REDSOCKS 2>/dev/null || true
|
|
||||||
iptables -t nat -F REDSOCKS
|
|
||||||
iptables -t nat -A REDSOCKS -d 127.0.0.0/8 -j RETURN
|
|
||||||
iptables -t nat -A REDSOCKS -d 127.0.0.11/32 -j RETURN
|
|
||||||
iptables -t nat -A REDSOCKS -d "${PROXY_IP}/32" -j RETURN
|
|
||||||
iptables -t nat -A REDSOCKS -p tcp --dport "${CONTROL_PORT}" -j RETURN
|
|
||||||
iptables -t nat -A REDSOCKS -p tcp -j REDIRECT --to-ports "${REDSOCKS_PORT}"
|
|
||||||
iptables -t nat -D OUTPUT -p tcp -j REDSOCKS 2>/dev/null || true
|
|
||||||
iptables -t nat -A OUTPUT -p tcp -j REDSOCKS
|
|
||||||
|
|
||||||
touch "${NETWORK_READY_FILE}"
|
|
||||||
|
|
||||||
exec gosu sub2api /app/lsworker
|
|
||||||
@ -1,52 +0,0 @@
|
|||||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
|
||||||
ARG DEBIAN_IMAGE=debian:bookworm-slim
|
|
||||||
|
|
||||||
FROM ${GOLANG_IMAGE} AS builder
|
|
||||||
|
|
||||||
WORKDIR /app/backend
|
|
||||||
RUN apk add --no-cache git ca-certificates tzdata
|
|
||||||
|
|
||||||
COPY backend/go.mod backend/go.sum ./
|
|
||||||
RUN go mod download
|
|
||||||
|
|
||||||
COPY backend/ ./
|
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /app/lsworker ./cmd/lsworker
|
|
||||||
|
|
||||||
FROM ${DEBIAN_IMAGE}
|
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
||||||
ca-certificates \
|
|
||||||
curl \
|
|
||||||
gosu \
|
|
||||||
iproute2 \
|
|
||||||
iptables \
|
|
||||||
redsocks \
|
|
||||||
tzdata \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
RUN groupadd -g 1000 sub2api && \
|
|
||||||
useradd -u 1000 -g sub2api -m -s /bin/sh sub2api
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
COPY --from=builder /app/lsworker /app/lsworker
|
|
||||||
COPY deploy/ls-bin/language_server_linux_* /tmp/ls-bin/
|
|
||||||
COPY deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/
|
|
||||||
|
|
||||||
ARG TARGETARCH
|
|
||||||
RUN mkdir -p /app/ls/extensions/antigravity/bin /run/lsworker && \
|
|
||||||
if [ "${TARGETARCH:-amd64}" = "arm64" ]; then \
|
|
||||||
cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \
|
|
||||||
else \
|
|
||||||
cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \
|
|
||||||
fi && \
|
|
||||||
chmod +x /app/lsworker /app/ls/extensions/antigravity/bin/language_server_linux_* && \
|
|
||||||
chown -R sub2api:sub2api /app /run/lsworker && \
|
|
||||||
rm -rf /tmp/ls-bin
|
|
||||||
|
|
||||||
COPY deploy/lsworker-entrypoint.sh /app/lsworker-entrypoint.sh
|
|
||||||
RUN chmod +x /app/lsworker-entrypoint.sh
|
|
||||||
|
|
||||||
EXPOSE 18081
|
|
||||||
|
|
||||||
ENTRYPOINT ["/app/lsworker-entrypoint.sh"]
|
|
||||||
Loading…
x
Reference in New Issue
Block a user