sub2api/backend/internal/pkg/lspool/worker_manager.go
win b856586412
Some checks failed
CI / test (push) Failing after 16m30s
CI / golangci-lint (push) Failing after 4s
Security Scan / backend-security (push) Failing after 1m35s
Security Scan / frontend-security (push) Failing after 1m31s
修复h1
2026-04-01 01:35:49 +08:00

681 lines
19 KiB
Go

package lspool
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
)
// Docker label keys/values stamped on worker containers, and the TCP port the
// worker's control endpoint is told to listen on.
const (
// Label key/value pair used to list the containers created by this manager.
lsWorkerManagedByLabel = "managed-by"
lsWorkerManagedByValue = "sub2api"
// Label keys recording the account ID, the proxy hash, and the image a
// container was created with.
lsWorkerAccountLabel = "account_id"
lsWorkerProxyHashLabel = "proxy_hash"
lsWorkerImageTagLabel = "image_tag"
// Port handed to the worker through its environment and appended to the
// container IP when building the worker address.
lsWorkerControlPort = 18081
)
// workerManagerConfig holds the resolved settings a workerManager runs with.
type workerManagerConfig struct {
// Image is the container image used for new workers.
Image string
// Network is the Docker network new workers are attached to and the
// preferred network when looking up a worker's IP address.
Network string
// DockerSocket is the Docker host address passed to the Docker client.
DockerSocket string
// IdleTTL is how long a worker may go unused before the cleanup pass removes it.
IdleTTL time.Duration
// MaxActive caps the number of workers tracked at once.
MaxActive int
// StartupTimeout bounds each of the health and readiness polling loops.
StartupTimeout time.Duration
// RequestTimeout is the HTTP client timeout and the deadline for state sync requests.
RequestTimeout time.Duration
}
// dockerClient is the subset of Docker client operations the worker manager
// calls. Declaring it as an interface lets newWorkerManager accept any
// implementation, not only the concrete Docker SDK client.
type dockerClient interface {
ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error)
ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error)
ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error)
ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
Close() error
}
// workerManager creates, tracks, and removes per-account worker containers
// and keeps the per-account state that is pushed to them.
type workerManager struct {
cfg workerManagerConfig
docker dockerClient
// http is the shared client used for all requests to worker containers.
http *http.Client
// mu is held around accesses to the workers and state maps.
mu sync.Mutex
// workers maps a worker key (see buildWorkerKey) to its handle.
workers map[string]*workerHandle
// state maps an account ID to the state recorded by the Set* methods.
state map[string]*workerAccountState
// ctx is cancelled by Close; it stops the cleanup loop and is used for
// Docker calls made while creating a worker.
ctx context.Context
cancel context.CancelFunc
logger *slog.Logger
}
// workerHandle describes one worker container tracked by the manager.
type workerHandle struct {
// Key is the map key built from AccountID and ProxyHash.
Key string
AccountID string
ProxyURL string
ProxyHash string
// ContainerID is the Docker ID used for stop/remove calls.
ContainerID string
// Container is the container name, used in logs and error messages.
Container string
// Address is the host:port the worker's HTTP endpoints are reached at.
Address string
// AuthToken is sent as the X-Worker-Token header on every request.
AuthToken string
// LastUsed is refreshed by GetOrCreate and compared against IdleTTL.
LastUsed time.Time
// LastStateSHA is the hex SHA-256 of the last state JSON the worker accepted.
LastStateSHA string
}
// workerAccountState is the per-account state kept by the manager and sent,
// JSON-encoded, to the worker's /account/state endpoint.
//
// The marshaled bytes are hashed to detect changes, so field order and tags
// affect that fingerprint.
type workerAccountState struct {
// HasToken is set once SetAccountToken has been called for the account.
HasToken bool `json:"has_token"`
AccessToken string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
// ExpiresAt is nil when no expiry was supplied; otherwise it is stored in UTC.
ExpiresAt *time.Time `json:"expires_at,omitempty"`
// HasModelCredits is set once SetAccountModelCredits has been called.
HasModelCredits bool `json:"has_model_credits"`
UseAICredits bool `json:"use_ai_credits"`
AvailableCredits *int32 `json:"available_credits,omitempty"`
MinimumCreditAmount *int32 `json:"minimum_credit_amount,omitempty"`
}
// NewWorkerManagerFromConfig builds a Docker-backed worker manager from the
// Gateway.AntigravityLSWorker section of cfg, substituting a default for every
// empty or non-positive setting.
//
// It returns an error when cfg is nil, when the Docker client cannot be
// constructed, or when newWorkerManager fails. On failure the returned Backend
// is a true nil interface, never an interface wrapping a nil *workerManager.
func NewWorkerManagerFromConfig(cfg *config.Config) (Backend, error) {
	if cfg == nil {
		return nil, fmt.Errorf("config is nil")
	}

	managerCfg := workerManagerConfig{
		Image:          strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Image),
		Network:        strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Network),
		DockerSocket:   strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.DockerSocket),
		IdleTTL:        cfg.Gateway.AntigravityLSWorker.IdleTTL,
		MaxActive:      cfg.Gateway.AntigravityLSWorker.MaxActive,
		StartupTimeout: cfg.Gateway.AntigravityLSWorker.StartupTimeout,
		RequestTimeout: cfg.Gateway.AntigravityLSWorker.RequestTimeout,
	}
	if managerCfg.Image == "" {
		managerCfg.Image = "weishaw/sub2api-lsworker:latest"
	}
	if managerCfg.Network == "" {
		managerCfg.Network = "sub2api-network"
	}
	if managerCfg.DockerSocket == "" {
		managerCfg.DockerSocket = "unix:///var/run/docker.sock"
	}
	if managerCfg.IdleTTL <= 0 {
		managerCfg.IdleTTL = 15 * time.Minute
	}
	if managerCfg.MaxActive < 1 {
		managerCfg.MaxActive = 50
	}
	if managerCfg.StartupTimeout <= 0 {
		managerCfg.StartupTimeout = 45 * time.Second
	}
	if managerCfg.RequestTimeout <= 0 {
		managerCfg.RequestTimeout = 60 * time.Second
	}

	// DockerSocket is always non-empty at this point (defaulted above), so the
	// host is always set explicitly; an environment-based fallback would be
	// unreachable.
	cli, err := client.NewClientWithOpts(
		client.WithAPIVersionNegotiation(),
		client.WithHost(managerCfg.DockerSocket),
	)
	if err != nil {
		return nil, fmt.Errorf("create docker client: %w", err)
	}

	// Return through an explicit nil check: passing a nil *workerManager
	// straight through would yield a non-nil Backend interface value.
	mgr, err := newWorkerManager(managerCfg, cli)
	if err != nil {
		return nil, err
	}
	return mgr, nil
}
// newWorkerManager wires a workerManager around the given Docker client,
// removes any containers left over from a previous run, and starts the
// background idle-cleanup loop.
//
// If the initial reconciliation fails, the manager context is cancelled, the
// Docker client is closed, and the reconciliation error is returned.
func newWorkerManager(cfg workerManagerConfig, docker dockerClient) (*workerManager, error) {
	dialer := &net.Dialer{
		Timeout:   5 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	// Proxy is left at its nil zero value: requests to workers never go
	// through an HTTP proxy.
	transport := &http.Transport{
		DialContext:         dialer.DialContext,
		MaxIdleConnsPerHost: 8,
	}
	httpClient := &http.Client{
		Timeout:   cfg.RequestTimeout,
		Transport: transport,
	}

	ctx, cancel := context.WithCancel(context.Background())
	mgr := &workerManager{
		cfg:     cfg,
		docker:  docker,
		http:    httpClient,
		workers: make(map[string]*workerHandle),
		state:   make(map[string]*workerAccountState),
		ctx:     ctx,
		cancel:  cancel,
		logger:  slog.Default().With("component", "lspool-worker-manager"),
	}

	err := mgr.reconcileManagedContainers(ctx)
	if err != nil {
		cancel()
		_ = docker.Close()
		return nil, err
	}

	go mgr.cleanupLoop()
	return mgr, nil
}
// Close stops the cleanup loop, removes every tracked worker container, and
// closes the Docker client.
//
// The worker map is emptied under the lock and containers are removed
// afterwards, so slow Docker calls do not block other users of the mutex.
// Removal failures are logged rather than returned because Close has no
// error result.
func (m *workerManager) Close() {
	m.cancel()

	m.mu.Lock()
	workers := make([]*workerHandle, 0, len(m.workers))
	for _, handle := range m.workers {
		workers = append(workers, handle)
	}
	m.workers = make(map[string]*workerHandle)
	m.mu.Unlock()

	for _, handle := range workers {
		// A background context is used because m.ctx was cancelled above.
		// removeWorkerContainer returns nil for a nil handle, so handle is
		// non-nil whenever err is non-nil.
		if err := m.removeWorkerContainer(context.Background(), handle); err != nil && m.logger != nil {
			m.logger.Warn("failed to remove ls worker on close",
				"container", handle.Container,
				"error", err)
		}
	}

	if m.docker != nil {
		_ = m.docker.Close()
	}
}
// Stats reports the number of accounts with recorded state and the number of
// tracked workers. "total" and "active" both carry the tracked-worker count.
func (m *workerManager) Stats() map[string]any {
	m.mu.Lock()
	accountCount, workerCount := len(m.state), len(m.workers)
	m.mu.Unlock()

	stats := make(map[string]any, 3)
	stats["accounts"] = accountCount
	stats["total"] = workerCount
	stats["active"] = workerCount
	return stats
}
// SetAccountToken records the access and refresh tokens for accountID and
// marks the account as having a token. A zero expiresAt clears the stored
// expiry; any other value is stored converted to UTC.
func (m *workerManager) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
	var expiry *time.Time
	if !expiresAt.IsZero() {
		utc := expiresAt.UTC()
		expiry = &utc
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	entry := m.ensureStateLocked(accountID)
	entry.HasToken = true
	entry.AccessToken = accessToken
	entry.RefreshToken = refreshToken
	entry.ExpiresAt = expiry
}
// SetAccountModelCredits records the credit settings for accountID and marks
// the account as having model-credit information. The pointer arguments are
// copied, so later changes by the caller do not affect the stored state.
func (m *workerManager) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
	available := cloneInt32Ptr(availableCredits)
	minimum := cloneInt32Ptr(minimumCreditAmountForUsage)

	m.mu.Lock()
	defer m.mu.Unlock()

	entry := m.ensureStateLocked(accountID)
	entry.HasModelCredits = true
	entry.UseAICredits = useAICredits
	entry.AvailableCredits = available
	entry.MinimumCreditAmount = minimum
}
// GetOrCreate returns an Instance that targets the worker container for the
// given account and proxy, creating and starting the container when no worker
// is registered under that key yet.
//
// Only the first element of proxyURL is considered. The call fails when the
// proxy cannot be resolved or is absent, when the account has no non-blank
// access token recorded, when the MaxActive limit is reached, when container
// creation fails, or when the worker does not pass the health, state-sync, or
// readiness steps.
func (m *workerManager) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) {
rawProxy := ""
if len(proxyURL) > 0 {
rawProxy = proxyURL[0]
}
normalizedProxy, parsedProxy, err := resolveWorkerProxy(rawProxy)
if err != nil {
return nil, err
}
// resolveWorkerProxy yields a nil URL when no proxy is configured; workers
// are never started without one.
if parsedProxy == nil {
return nil, fmt.Errorf("ls worker requires a socks5/socks5h proxy for account %s", accountID)
}
replica := replicaSlotIndex(routingKey, parseLSReplicaCount())
proxyHash := proxyHash(normalizedProxy)
workerKey := buildWorkerKey(accountID, proxyHash)
m.mu.Lock()
// Work on a private copy of the account state so it can be used after the
// lock is released.
state := cloneWorkerAccountState(m.state[accountID])
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
m.mu.Unlock()
return nil, fmt.Errorf("ls worker missing access token for account %s", accountID)
}
handle := m.workers[workerKey]
if handle == nil {
if len(m.workers) >= m.cfg.MaxActive {
m.mu.Unlock()
return nil, fmt.Errorf("ls worker limit reached (%d active)", m.cfg.MaxActive)
}
// NOTE(review): the container is created, started, and inspected while
// m.mu is held, so other callers needing the mutex wait on Docker.
handle, err = m.createWorkerLocked(accountID, normalizedProxy, proxyHash, parsedProxy)
if err != nil {
m.mu.Unlock()
return nil, err
}
m.workers[workerKey] = handle
}
handle.LastUsed = time.Now()
m.mu.Unlock()
// The remaining steps run without the lock.
// NOTE(review): if any of them fails, the handle stays registered in
// m.workers — confirm that is intended.
if err := m.waitForWorkerHealthy(handle); err != nil {
return nil, err
}
// NOTE(review): syncWorkerState reads and writes handle.LastStateSHA without
// m.mu; concurrent calls for the same worker may race — confirm.
if err := m.syncWorkerState(handle, state); err != nil {
return nil, err
}
if err := m.waitForWorkerReady(handle, routingKey); err != nil {
return nil, err
}
inst := &Instance{
AccountID: accountID,
Replica: replica,
Address: handle.Address,
client: m.http,
healthy: true,
lastUsed: time.Now(),
modelMapReady: 1,
remote: true,
workerToken: handle.AuthToken,
routingKey: routingKey,
}
return inst, nil
}
// cleanupLoop runs an idle-worker collection pass once a minute until the
// manager context is cancelled.
func (m *workerManager) cleanupLoop() {
	tick := time.NewTicker(time.Minute)
	defer tick.Stop()

	done := m.ctx.Done()
	for {
		select {
		case <-tick.C:
			m.collectIdleWorkers()
		case <-done:
			return
		}
	}
}
// collectIdleWorkers unregisters every worker that has been unused for longer
// than IdleTTL and removes its container. Nil map entries are dropped as well.
//
// Handles are removed from the map under the lock; the Docker calls happen
// afterwards. A failed removal is logged, because the handle is already
// unregistered and the container would otherwise linger unnoticed.
func (m *workerManager) collectIdleWorkers() {
	now := time.Now()
	var expired []*workerHandle

	m.mu.Lock()
	for key, handle := range m.workers {
		if handle == nil {
			delete(m.workers, key)
			continue
		}
		if now.Sub(handle.LastUsed) <= m.cfg.IdleTTL {
			continue
		}
		expired = append(expired, handle)
		delete(m.workers, key)
	}
	m.mu.Unlock()

	for _, handle := range expired {
		if err := m.removeWorkerContainer(context.Background(), handle); err != nil && m.logger != nil {
			m.logger.Warn("failed to remove idle ls worker",
				"container", handle.Container,
				"error", err)
		}
	}
}
// reconcileManagedContainers lists every container (running or not) carrying
// the manager's managed-by label and removes each one. It stops at, and
// returns, the first listing or removal error.
func (m *workerManager) reconcileManagedContainers(ctx context.Context) error {
	labelFilter := filters.NewArgs()
	labelFilter.Add("label", fmt.Sprintf("%s=%s", lsWorkerManagedByLabel, lsWorkerManagedByValue))

	listOpts := container.ListOptions{All: true, Filters: labelFilter}
	found, err := m.docker.ContainerList(ctx, listOpts)
	if err != nil {
		return fmt.Errorf("list managed ls workers: %w", err)
	}

	for i := range found {
		stale := &workerHandle{
			ContainerID: found[i].ID,
			Container:   strings.TrimPrefix(firstContainerName(found[i].Names), "/"),
		}
		if removeErr := m.removeWorkerContainer(ctx, stale); removeErr != nil {
			return removeErr
		}
	}
	return nil
}
// createWorkerLocked creates and starts a worker container for accountID that
// routes through the given proxy, then resolves the address its control
// endpoint is reachable at.
//
// proxyURL is the normalized proxy string passed to the worker, proxyHash its
// hash (used for labels and, abbreviated, in the container name), and
// parsedProxy the parsed form from which host, port, and credentials are taken.
// A missing port defaults to 1080.
//
// GetOrCreate calls this with m.mu held. On any failure after the container
// has been created, the container is force-removed (best effort) and an error
// is returned.
func (m *workerManager) createWorkerLocked(accountID, proxyURL, proxyHash string, parsedProxy *url.URL) (*workerHandle, error) {
	if parsedProxy == nil {
		return nil, fmt.Errorf("create ls worker container: proxy is required")
	}

	// Abbreviate the hash for names and logs without assuming its length;
	// slicing a hash shorter than 8 bytes unconditionally would panic.
	shortHash := proxyHash
	if len(shortHash) > 8 {
		shortHash = shortHash[:8]
	}

	containerName := fmt.Sprintf("sub2api-ls-%s-%s", accountID, shortHash)
	authToken := generateUUID()
	proxyHost := parsedProxy.Hostname()
	proxyPort := parsedProxy.Port()
	if proxyPort == "" {
		proxyPort = "1080"
	}
	proxyUser := parsedProxy.User.Username()
	proxyPass, _ := parsedProxy.User.Password()

	labels := map[string]string{
		lsWorkerManagedByLabel: lsWorkerManagedByValue,
		lsWorkerAccountLabel:   accountID,
		lsWorkerProxyHashLabel: proxyHash,
		lsWorkerImageTagLabel:  m.cfg.Image,
	}
	env := []string{
		"ANTIGRAVITY_APP_ROOT=/app/ls",
		fmt.Sprintf("LSWORKER_ACCOUNT_ID=%s", accountID),
		fmt.Sprintf("LSWORKER_AUTH_TOKEN=%s", authToken),
		fmt.Sprintf("LSWORKER_LISTEN_ADDR=0.0.0.0:%d", lsWorkerControlPort),
		fmt.Sprintf("LSWORKER_NETWORK_READY_FILE=%s", "/run/lsworker/network-ready"),
		fmt.Sprintf("LSWORKER_PROXY_URL=%s", proxyURL),
		fmt.Sprintf("LSWORKER_PROXY_HOST=%s", proxyHost),
		fmt.Sprintf("LSWORKER_PROXY_PORT=%s", proxyPort),
		fmt.Sprintf("LSWORKER_PROXY_USER=%s", proxyUser),
		fmt.Sprintf("LSWORKER_PROXY_PASS=%s", proxyPass),
		fmt.Sprintf("LSWORKER_CONTROL_PORT=%d", lsWorkerControlPort),
		fmt.Sprintf("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=%d", parseLSReplicaCount()),
	}
	// Propagate the host process's time zone setting when one is set.
	if tz := strings.TrimSpace(os.Getenv("TZ")); tz != "" {
		env = append(env, "TZ="+tz)
	}

	createResp, err := m.docker.ContainerCreate(m.ctx, &container.Config{
		Image:  m.cfg.Image,
		Labels: labels,
		Env:    env,
	}, &container.HostConfig{
		CapAdd: []string{"NET_ADMIN"},
	}, &network.NetworkingConfig{
		EndpointsConfig: map[string]*network.EndpointSettings{
			m.cfg.Network: {},
		},
	}, nil, containerName)
	if err != nil {
		return nil, fmt.Errorf("create ls worker container: %w", err)
	}

	// discard force-removes the just-created container; its own error is
	// ignored because the caller already reports the primary failure.
	discard := func() {
		_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
	}

	if err := m.docker.ContainerStart(m.ctx, createResp.ID, container.StartOptions{}); err != nil {
		discard()
		return nil, fmt.Errorf("start ls worker container: %w", err)
	}
	inspect, err := m.docker.ContainerInspect(m.ctx, createResp.ID)
	if err != nil {
		discard()
		return nil, fmt.Errorf("inspect ls worker container: %w", err)
	}
	address, err := workerAddressFromInspect(inspect, m.cfg.Network)
	if err != nil {
		discard()
		return nil, err
	}

	m.logger.Info("created ls worker",
		"account", shortAccountID(accountID),
		"container", containerName,
		"address", address,
		"proxy_hash", shortHash)

	return &workerHandle{
		Key:         buildWorkerKey(accountID, proxyHash),
		AccountID:   accountID,
		ProxyURL:    proxyURL,
		ProxyHash:   proxyHash,
		ContainerID: createResp.ID,
		Container:   containerName,
		Address:     address,
		AuthToken:   authToken,
		LastUsed:    time.Now(),
	}, nil
}
// workerAddressFromInspect derives the worker's control address (IP joined
// with lsWorkerControlPort) from a container inspect result.
//
// The endpoint on networkName is preferred. If it is missing or has no IP,
// the first endpoint found with a non-blank IP is used; map iteration order
// decides which one when several qualify. An error is returned when the
// inspect result has no network settings or no endpoint has an IP.
func workerAddressFromInspect(inspect container.InspectResponse, networkName string) (string, error) {
	settings := inspect.NetworkSettings
	if settings == nil {
		return "", fmt.Errorf("ls worker inspect missing network settings")
	}

	port := strconv.Itoa(lsWorkerControlPort)

	if preferred := settings.Networks[networkName]; preferred != nil {
		if ip := strings.TrimSpace(preferred.IPAddress); ip != "" {
			return net.JoinHostPort(ip, port), nil
		}
	}

	for _, candidate := range settings.Networks {
		if candidate == nil {
			continue
		}
		if ip := strings.TrimSpace(candidate.IPAddress); ip != "" {
			return net.JoinHostPort(ip, port), nil
		}
	}

	return "", fmt.Errorf("ls worker missing IP address on network %s", networkName)
}
// firstContainerName returns the first entry of names, or the empty string
// when names is empty.
func firstContainerName(names []string) string {
	if len(names) > 0 {
		return names[0]
	}
	return ""
}
// waitForWorkerHealthy polls the worker's /healthz endpoint every 500ms until
// it answers 200 OK or StartupTimeout elapses.
//
// On timeout the returned error wraps the context error and, when a probe was
// made, reports the last transport error or HTTP status seen so the cause of
// the failure is visible to the caller.
func (m *workerManager) waitForWorkerHealthy(handle *workerHandle) error {
	ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
	defer cancel()

	var (
		lastStatus int
		lastErr    error
	)
	for {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/healthz", nil), nil)
		if err != nil {
			return err
		}
		req.Header.Set("X-Worker-Token", handle.AuthToken)

		resp, err := m.http.Do(req)
		if err != nil {
			lastStatus, lastErr = 0, err
		} else {
			// Drain a bounded amount before closing so the keep-alive
			// connection can be reused by the next probe.
			_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4<<10))
			_ = resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				return nil
			}
			lastStatus, lastErr = resp.StatusCode, nil
		}

		select {
		case <-ctx.Done():
			switch {
			case lastErr != nil:
				return fmt.Errorf("worker %s failed health check (last_error=%v): %w", handle.Container, lastErr, ctx.Err())
			case lastStatus > 0:
				return fmt.Errorf("worker %s failed health check (last_status=%d): %w", handle.Container, lastStatus, ctx.Err())
			}
			return fmt.Errorf("worker %s failed health check: %w", handle.Container, ctx.Err())
		case <-time.After(500 * time.Millisecond):
		}
	}
}
// waitForWorkerReady polls the worker's /readyz endpoint every 500ms until it
// answers 200 OK or StartupTimeout elapses. A non-blank routingKey is sent as
// the routing_key query parameter.
//
// It returns early with an error wrapping errLSModelMapDenied when the
// response matches isWorkerModelMappingUnavailable. On timeout the error wraps
// the context error and includes the last status and body excerpt, if any
// response was received.
func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey string) error {
	// Cap on how much of each response body is buffered. Only a 200-length
	// excerpt is ever reported, so reading without a limit on every poll
	// would only waste memory.
	const maxReadyBodyBytes = 64 << 10

	ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
	defer cancel()

	values := url.Values{}
	if strings.TrimSpace(routingKey) != "" {
		values.Set("routing_key", routingKey)
	}

	var (
		lastStatus int
		lastBody   string
	)
	for {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/readyz", values), nil)
		if err != nil {
			return err
		}
		req.Header.Set("X-Worker-Token", handle.AuthToken)

		resp, err := m.http.Do(req)
		if err == nil {
			body, _ := io.ReadAll(io.LimitReader(resp.Body, maxReadyBodyBytes))
			_ = resp.Body.Close()
			lastStatus = resp.StatusCode
			lastBody = truncate(string(body), 200)
			if resp.StatusCode == http.StatusOK {
				return nil
			}
			// A denied model mapping is not retried.
			if isWorkerModelMappingUnavailable(resp.StatusCode, lastBody) {
				return fmt.Errorf("%w: worker %s %s", errLSModelMapDenied, handle.Container, strings.TrimSpace(lastBody))
			}
			if len(body) > 0 && shouldWarnWorkerNotReady(resp.StatusCode, lastBody) {
				m.logger.Warn("ls worker not ready yet", "container", handle.Container, "status", resp.StatusCode, "body", lastBody)
			}
		}

		select {
		case <-ctx.Done():
			if lastStatus > 0 || lastBody != "" {
				return fmt.Errorf("worker %s not ready for routing key %q (last_status=%d last_body=%q): %w", handle.Container, routingKey, lastStatus, lastBody, ctx.Err())
			}
			return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, ctx.Err())
		case <-time.After(500 * time.Millisecond):
		}
	}
}
// shouldWarnWorkerNotReady reports whether a not-ready response deserves a
// warning log. Only a 503 whose body mentions "model mapping not ready"
// (case-insensitively) is suppressed; every other response warns.
func shouldWarnWorkerNotReady(status int, body string) bool {
	if status != http.StatusServiceUnavailable {
		return true
	}
	lowered := strings.ToLower(strings.TrimSpace(body))
	return !strings.Contains(lowered, "model mapping not ready")
}
// isWorkerModelMappingUnavailable reports whether a readiness response is a
// 503 whose body contains the errLSModelMapDenied message.
//
// The comparison is case-insensitive on both sides: the body is lowercased,
// so the sentinel's message is lowercased as well. Otherwise any uppercase
// character in that message would make a match impossible.
func isWorkerModelMappingUnavailable(status int, body string) bool {
	if status != http.StatusServiceUnavailable {
		return false
	}
	normalized := strings.ToLower(strings.TrimSpace(body))
	needle := strings.ToLower(errLSModelMapDenied.Error())
	return strings.Contains(normalized, needle)
}
// syncWorkerState POSTs the JSON-encoded account state to the worker's
// /account/state endpoint, unless the state's SHA-256 matches the one the
// worker last accepted (handle.LastStateSHA), in which case nothing is sent.
//
// The fingerprint is stored on the handle only after a 200 OK. It returns an
// error for a nil state, a marshal or transport failure, or a non-200 status.
//
// NOTE(review): handle.LastStateSHA is read and written here without any
// lock; confirm callers never sync the same handle concurrently.
func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error {
	if state == nil {
		return fmt.Errorf("ls worker state is nil")
	}

	payload, err := json.Marshal(state)
	if err != nil {
		return fmt.Errorf("marshal worker state: %w", err)
	}

	fingerprint := fmt.Sprintf("%x", sha256.Sum256(payload))
	if fingerprint == handle.LastStateSHA {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout)
	defer cancel()

	endpoint := workerURL(handle, "/account/state", nil)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-Worker-Token", handle.AuthToken)

	resp, err := m.http.Do(req)
	if err != nil {
		return fmt.Errorf("sync worker state: %w", err)
	}
	defer resp.Body.Close()

	respBody, _ := io.ReadAll(resp.Body)
	if resp.StatusCode == http.StatusOK {
		handle.LastStateSHA = fingerprint
		return nil
	}
	return fmt.Errorf("sync worker state HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200))
}
// workerURL builds an http:// URL for the given path on the worker's address.
// A non-nil values is encoded as the query string.
func workerURL(handle *workerHandle, path string, values url.Values) string {
	var query string
	if values != nil {
		query = values.Encode()
	}
	target := url.URL{
		Scheme:   "http",
		Host:     handle.Address,
		Path:     path,
		RawQuery: query,
	}
	return target.String()
}
// removeWorkerContainer stops the handle's container with a 5-second stop
// timeout and then force-removes it. A nil handle or blank container ID is a
// no-op. The stop error is deliberately ignored, since the forced removal
// follows regardless; only a removal failure is returned.
func (m *workerManager) removeWorkerContainer(ctx context.Context, handle *workerHandle) error {
	if handle == nil {
		return nil
	}
	id := handle.ContainerID
	if strings.TrimSpace(id) == "" {
		return nil
	}

	graceSeconds := 5
	_ = m.docker.ContainerStop(ctx, id, container.StopOptions{Timeout: &graceSeconds})

	err := m.docker.ContainerRemove(ctx, id, container.RemoveOptions{Force: true})
	if err == nil {
		return nil
	}
	return fmt.Errorf("remove ls worker container %s: %w", id, err)
}
// ensureStateLocked returns the state entry for accountID, inserting an empty
// one when the account has no (or a nil) entry. Visible callers invoke it
// while holding m.mu.
func (m *workerManager) ensureStateLocked(accountID string) *workerAccountState {
	if existing := m.state[accountID]; existing != nil {
		return existing
	}
	fresh := &workerAccountState{}
	m.state[accountID] = fresh
	return fresh
}
func resolveWorkerProxy(proxyURL string) (string, *url.URL, error) {
resolved := resolveLSProxy(proxyURL)
normalized, parsed, err := proxyurl.Parse(resolved)
if err != nil {
return "", nil, err
}
if parsed == nil {
return "", nil, nil
}
switch strings.ToLower(parsed.Scheme) {
case "socks5", "socks5h":
return normalized, parsed, nil
default:
return "", nil, fmt.Errorf("ls worker only supports socks5/socks5h proxies, got %s", parsed.Scheme)
}
}
// proxyHash returns the lowercase hex SHA-256 of the whitespace-trimmed proxy
// URL, or the literal "direct" when the URL is blank.
func proxyHash(proxyURL string) string {
	trimmed := strings.TrimSpace(proxyURL)
	if trimmed == "" {
		return "direct"
	}
	digest := sha256.Sum256([]byte(trimmed))
	return fmt.Sprintf("%x", digest[:])
}
// buildWorkerKey joins an account ID and a proxy hash with ":" to form the
// key under which a worker is tracked.
func buildWorkerKey(accountID, proxyHash string) string {
	const separator = ":"
	return accountID + separator + proxyHash
}
// cloneInt32Ptr returns a pointer to a fresh copy of *v, or nil when v is nil.
func cloneInt32Ptr(v *int32) *int32 {
	if v != nil {
		dup := new(int32)
		*dup = *v
		return dup
	}
	return nil
}
// cloneWorkerAccountState returns a copy of state whose pointer fields
// (ExpiresAt, AvailableCredits, MinimumCreditAmount) point at fresh copies,
// so the result shares no memory with the input. A nil state yields nil.
func cloneWorkerAccountState(state *workerAccountState) *workerAccountState {
	if state == nil {
		return nil
	}

	dup := new(workerAccountState)
	*dup = *state
	if state.ExpiresAt != nil {
		expiry := *state.ExpiresAt
		dup.ExpiresAt = &expiry
	}
	dup.AvailableCredits = cloneInt32Ptr(state.AvailableCredits)
	dup.MinimumCreditAmount = cloneInt32Ptr(state.MinimumCreditAmount)
	return dup
}