package lspool

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/filters"
	"github.com/docker/docker/api/types/network"
	"github.com/docker/docker/client"
	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
)

// Labels attached to every worker container created by this manager, plus the
// TCP port on which a worker exposes its control API.
const (
	lsWorkerManagedByLabel = "managed-by"
	lsWorkerManagedByValue = "sub2api"
	lsWorkerAccountLabel   = "account_id"
	lsWorkerProxyHashLabel = "proxy_hash"
	lsWorkerImageTagLabel  = "image_tag"
	lsWorkerControlPort    = 18081
)

// workerManagerConfig holds the resolved (defaults applied) settings used to
// create and talk to worker containers.
type workerManagerConfig struct {
	Image          string        // container image used for new workers
	Network        string        // Docker network the worker is attached to
	DockerSocket   string        // Docker daemon host passed to client.WithHost
	IdleTTL        time.Duration // workers unused for longer than this are removed
	MaxActive      int           // upper bound on concurrently tracked workers
	StartupTimeout time.Duration // budget for each of the health and readiness waits
	RequestTimeout time.Duration // HTTP client timeout and state-sync request timeout
}

// dockerClient is the subset of the Docker API client used by workerManager.
type dockerClient interface {
	ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error)
	ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error)
	ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
	ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error)
	ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
	ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
	Close() error
}

// workerManager creates, tracks and removes one worker container per
// (account, proxy) pair and keeps per-account state to push to those workers.
//
// mu guards workers, state, and the mutable fields of every workerHandle
// (LastUsed, LastStateSHA).
type workerManager struct {
	cfg     workerManagerConfig
	docker  dockerClient
	http    *http.Client
	mu      sync.Mutex
	workers map[string]*workerHandle       // keyed by buildWorkerKey(accountID, proxyHash)
	state   map[string]*workerAccountState // keyed by account ID
	ctx     context.Context                // cancelled by Close; stops cleanupLoop
	cancel  context.CancelFunc
	logger  *slog.Logger
}

// workerHandle describes one running worker container.
type workerHandle struct {
	Key          string
	AccountID    string
	ProxyURL     string
	ProxyHash    string
	ContainerID  string
	Container    string    // container name
	Address      string    // host:port of the worker control API
	AuthToken    string    // sent as X-Worker-Token on every request to the worker
	LastUsed     time.Time // updated on every GetOrCreate; drives idle collection
	LastStateSHA string    // SHA-256 (hex) of the last state payload accepted by the worker
}

// workerAccountState is the JSON payload posted to a worker's /account/state
// endpoint.
type workerAccountState struct {
	HasToken            bool       `json:"has_token"`
	AccessToken         string     `json:"access_token,omitempty"`
	RefreshToken        string     `json:"refresh_token,omitempty"`
	ExpiresAt           *time.Time `json:"expires_at,omitempty"`
	HasModelCredits     bool       `json:"has_model_credits"`
	UseAICredits        bool       `json:"use_ai_credits"`
	AvailableCredits    *int32     `json:"available_credits,omitempty"`
	MinimumCreditAmount *int32     `json:"minimum_credit_amount,omitempty"`
}

// NewWorkerManagerFromConfig builds a worker manager from the
// Gateway.AntigravityLSWorker section of cfg, applying defaults for every
// empty or non-positive setting, and connects to the Docker daemon.
//
// It returns an error when cfg is nil, the Docker client cannot be created, or
// the initial reconciliation of existing managed containers fails. On error
// the returned Backend is a literal nil.
func NewWorkerManagerFromConfig(cfg *config.Config) (Backend, error) {
	if cfg == nil {
		return nil, fmt.Errorf("config is nil")
	}

	managerCfg := workerManagerConfig{
		Image:          strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Image),
		Network:        strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Network),
		DockerSocket:   strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.DockerSocket),
		IdleTTL:        cfg.Gateway.AntigravityLSWorker.IdleTTL,
		MaxActive:      cfg.Gateway.AntigravityLSWorker.MaxActive,
		StartupTimeout: cfg.Gateway.AntigravityLSWorker.StartupTimeout,
		RequestTimeout: cfg.Gateway.AntigravityLSWorker.RequestTimeout,
	}
	if managerCfg.Image == "" {
		managerCfg.Image = "weishaw/sub2api-lsworker:latest"
	}
	if managerCfg.Network == "" {
		managerCfg.Network = "sub2api-network"
	}
	if managerCfg.DockerSocket == "" {
		managerCfg.DockerSocket = "unix:///var/run/docker.sock"
	}
	if managerCfg.IdleTTL <= 0 {
		managerCfg.IdleTTL = 15 * time.Minute
	}
	if managerCfg.MaxActive < 1 {
		managerCfg.MaxActive = 50
	}
	if managerCfg.StartupTimeout <= 0 {
		managerCfg.StartupTimeout = 45 * time.Second
	}
	if managerCfg.RequestTimeout <= 0 {
		managerCfg.RequestTimeout = 60 * time.Second
	}

	// DockerSocket is never empty here (defaulted above), so the host is
	// always set explicitly.
	cli, err := client.NewClientWithOpts(
		client.WithAPIVersionNegotiation(),
		client.WithHost(managerCfg.DockerSocket),
	)
	if err != nil {
		return nil, fmt.Errorf("create docker client: %w", err)
	}

	// Return a literal nil on failure: passing a nil *workerManager through
	// the Backend interface would yield a non-nil interface value.
	mgr, err := newWorkerManager(managerCfg, cli)
	if err != nil {
		return nil, err
	}
	return mgr, nil
}

// newWorkerManager wires up a manager around an existing Docker client,
// removes any managed containers already present, and starts the idle-cleanup
// goroutine. If reconciliation fails, docker is closed and the error returned.
func newWorkerManager(cfg workerManagerConfig, docker dockerClient) (*workerManager, error) {
	ctx, cancel := context.WithCancel(context.Background())
	mgr := &workerManager{
		cfg:    cfg,
		docker: docker,
		http: &http.Client{
			Timeout: cfg.RequestTimeout,
			Transport: &http.Transport{
				// A nil Proxy means requests to workers never go through a proxy.
				Proxy: nil,
				DialContext: (&net.Dialer{
					Timeout:   5 * time.Second,
					KeepAlive: 30 * time.Second,
				}).DialContext,
				MaxIdleConnsPerHost: 8,
			},
		},
		workers: make(map[string]*workerHandle),
		state:   make(map[string]*workerAccountState),
		ctx:     ctx,
		cancel:  cancel,
		logger:  slog.Default().With("component", "lspool-worker-manager"),
	}

	if err := mgr.reconcileManagedContainers(ctx); err != nil {
		cancel()
		_ = docker.Close() // already failing; the reconcile error is the one reported
		return nil, err
	}

	go mgr.cleanupLoop()
	return mgr, nil
}

// Close stops the cleanup loop, removes every tracked worker container
// (best-effort; failures are logged) and closes the Docker client.
func (m *workerManager) Close() {
	m.cancel()

	m.mu.Lock()
	workers := make([]*workerHandle, 0, len(m.workers))
	for _, handle := range m.workers {
		workers = append(workers, handle)
	}
	m.workers = make(map[string]*workerHandle)
	m.mu.Unlock()

	// m.ctx is already cancelled, so removal uses a fresh background context.
	for _, handle := range workers {
		if err := m.removeWorkerContainer(context.Background(), handle); err != nil {
			m.logger.Warn("failed to remove ls worker on close", "container", handle.Container, "error", err)
		}
	}

	if m.docker != nil {
		_ = m.docker.Close() // nothing useful to do with a close error at shutdown
	}
}

// Stats reports the number of accounts with stored state and the number of
// tracked workers ("total" and "active" are both the tracked-worker count).
func (m *workerManager) Stats() map[string]any {
	m.mu.Lock()
	defer m.mu.Unlock()

	return map[string]any{
		"accounts": len(m.state),
		"total":    len(m.workers),
		"active":   len(m.workers),
	}
}

// SetAccountToken records the OAuth tokens for accountID. A zero expiresAt
// clears the stored expiry; otherwise the expiry is stored in UTC.
func (m *workerManager) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
	m.mu.Lock()
	defer m.mu.Unlock()

	state := m.ensureStateLocked(accountID)
	state.HasToken = true
	state.AccessToken = accessToken
	state.RefreshToken = refreshToken
	if expiresAt.IsZero() {
		state.ExpiresAt = nil
		return
	}
	ts := expiresAt.UTC()
	state.ExpiresAt = &ts
}

// SetAccountModelCredits records the model-credit settings for accountID. The
// pointer arguments are copied, so later changes by the caller are not seen.
func (m *workerManager) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
	m.mu.Lock()
	defer m.mu.Unlock()

	state := m.ensureStateLocked(accountID)
	state.HasModelCredits = true
	state.UseAICredits = useAICredits
	state.AvailableCredits = cloneInt32Ptr(availableCredits)
	state.MinimumCreditAmount = cloneInt32Ptr(minimumCreditAmountForUsage)
}

// GetOrCreate returns an Instance bound to the worker for (accountID, proxy),
// creating the worker container if none is tracked yet. Only the first
// element of proxyURL is used.
//
// It fails when the proxy is missing or not socks5/socks5h, when no access
// token has been stored for the account, when MaxActive workers are already
// tracked, or when the worker cannot be created, become healthy, accept the
// account state, or become ready for routingKey within the configured
// timeouts.
func (m *workerManager) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) {
	rawProxy := ""
	if len(proxyURL) > 0 {
		rawProxy = proxyURL[0]
	}
	normalizedProxy, parsedProxy, err := resolveWorkerProxy(rawProxy)
	if err != nil {
		return nil, err
	}
	if parsedProxy == nil {
		return nil, fmt.Errorf("ls worker requires a socks5/socks5h proxy for account %s", accountID)
	}

	replica := replicaSlotIndex(routingKey, parseLSReplicaCount())
	hash := proxyHash(normalizedProxy)
	workerKey := buildWorkerKey(accountID, hash)

	m.mu.Lock()
	// Work on a snapshot so the state can be synced after the lock is released.
	state := cloneWorkerAccountState(m.state[accountID])
	if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
		m.mu.Unlock()
		return nil, fmt.Errorf("ls worker missing access token for account %s", accountID)
	}
	handle := m.workers[workerKey]
	if handle == nil {
		if len(m.workers) >= m.cfg.MaxActive {
			m.mu.Unlock()
			return nil, fmt.Errorf("ls worker limit reached (%d active)", m.cfg.MaxActive)
		}
		// NOTE(review): container creation runs with m.mu held, so other
		// callers of the manager block until the Docker calls return.
		handle, err = m.createWorkerLocked(accountID, normalizedProxy, hash, parsedProxy)
		if err != nil {
			m.mu.Unlock()
			return nil, err
		}
		m.workers[workerKey] = handle
	}
	handle.LastUsed = time.Now()
	m.mu.Unlock()

	if err := m.waitForWorkerHealthy(handle); err != nil {
		return nil, err
	}
	if err := m.syncWorkerState(handle, state); err != nil {
		return nil, err
	}
	if err := m.waitForWorkerReady(handle, routingKey); err != nil {
		return nil, err
	}

	return &Instance{
		AccountID:     accountID,
		Replica:       replica,
		Address:       handle.Address,
		client:        m.http,
		healthy:       true,
		lastUsed:      time.Now(),
		modelMapReady: 1,
		remote:        true,
		workerToken:   handle.AuthToken,
		routingKey:    routingKey,
	}, nil
}

// cleanupLoop runs collectIdleWorkers once a minute until m.ctx is cancelled.
func (m *workerManager) cleanupLoop() {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-m.ctx.Done():
			return
		case <-ticker.C:
			m.collectIdleWorkers()
		}
	}
}

// collectIdleWorkers drops nil entries and untracks every worker whose
// LastUsed is older than IdleTTL, then removes the expired containers outside
// the lock. Removal is best-effort; failures are logged.
func (m *workerManager) collectIdleWorkers() {
	now := time.Now()
	var expired []*workerHandle

	m.mu.Lock()
	for key, handle := range m.workers {
		if handle == nil {
			delete(m.workers, key)
			continue
		}
		if now.Sub(handle.LastUsed) <= m.cfg.IdleTTL {
			continue
		}
		expired = append(expired, handle)
		delete(m.workers, key)
	}
	m.mu.Unlock()

	for _, handle := range expired {
		if err := m.removeWorkerContainer(context.Background(), handle); err != nil {
			m.logger.Warn("failed to remove idle ls worker", "container", handle.Container, "error", err)
		}
	}
}

// reconcileManagedContainers lists every container (running or not) carrying
// the managed-by label and removes it. It stops at the first removal error.
func (m *workerManager) reconcileManagedContainers(ctx context.Context) error {
	args := filters.NewArgs()
	args.Add("label", fmt.Sprintf("%s=%s", lsWorkerManagedByLabel, lsWorkerManagedByValue))

	containers, err := m.docker.ContainerList(ctx, container.ListOptions{
		All:     true,
		Filters: args,
	})
	if err != nil {
		return fmt.Errorf("list managed ls workers: %w", err)
	}

	for _, summary := range containers {
		handle := &workerHandle{
			ContainerID: summary.ID,
			Container:   strings.TrimPrefix(firstContainerName(summary.Names), "/"),
		}
		if err := m.removeWorkerContainer(ctx, handle); err != nil {
			return err
		}
	}
	return nil
}

// createWorkerLocked creates and starts a worker container for accountID and
// returns a handle holding its address and auth token. The caller holds m.mu
// (the visible call site in GetOrCreate does). If start, inspect or address
// resolution fails, the freshly created container is force-removed.
func (m *workerManager) createWorkerLocked(accountID, proxyURL, proxyHash string, parsedProxy *url.URL) (*workerHandle, error) {
	// Use at most the first 8 characters of the hash; guard against hashes
	// shorter than that (proxyHash() can return "direct") to avoid a
	// slice-bounds panic.
	shortHash := proxyHash
	if len(shortHash) > 8 {
		shortHash = shortHash[:8]
	}

	containerName := fmt.Sprintf("sub2api-ls-%s-%s", accountID, shortHash)
	authToken := generateUUID()

	proxyHost := parsedProxy.Hostname()
	proxyPort := parsedProxy.Port()
	if proxyPort == "" {
		proxyPort = "1080"
	}
	proxyUser := parsedProxy.User.Username()
	proxyPass, _ := parsedProxy.User.Password()

	labels := map[string]string{
		lsWorkerManagedByLabel: lsWorkerManagedByValue,
		lsWorkerAccountLabel:   accountID,
		lsWorkerProxyHashLabel: proxyHash,
		lsWorkerImageTagLabel:  m.cfg.Image,
	}

	env := []string{
		"ANTIGRAVITY_APP_ROOT=/app/ls",
		fmt.Sprintf("LSWORKER_ACCOUNT_ID=%s", accountID),
		fmt.Sprintf("LSWORKER_AUTH_TOKEN=%s", authToken),
		fmt.Sprintf("LSWORKER_LISTEN_ADDR=0.0.0.0:%d", lsWorkerControlPort),
		"LSWORKER_NETWORK_READY_FILE=/run/lsworker/network-ready",
		fmt.Sprintf("LSWORKER_PROXY_URL=%s", proxyURL),
		fmt.Sprintf("LSWORKER_PROXY_HOST=%s", proxyHost),
		fmt.Sprintf("LSWORKER_PROXY_PORT=%s", proxyPort),
		fmt.Sprintf("LSWORKER_PROXY_USER=%s", proxyUser),
		fmt.Sprintf("LSWORKER_PROXY_PASS=%s", proxyPass),
		fmt.Sprintf("LSWORKER_CONTROL_PORT=%d", lsWorkerControlPort),
		fmt.Sprintf("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=%d", parseLSReplicaCount()),
	}
	// Propagate the host process time zone to the worker when one is set.
	if tz := strings.TrimSpace(os.Getenv("TZ")); tz != "" {
		env = append(env, "TZ="+tz)
	}

	createResp, err := m.docker.ContainerCreate(
		m.ctx,
		&container.Config{
			Image:  m.cfg.Image,
			Labels: labels,
			Env:    env,
		},
		&container.HostConfig{
			CapAdd: []string{"NET_ADMIN"},
		},
		&network.NetworkingConfig{
			EndpointsConfig: map[string]*network.EndpointSettings{
				m.cfg.Network: {},
			},
		},
		nil,
		containerName,
	)
	if err != nil {
		return nil, fmt.Errorf("create ls worker container: %w", err)
	}

	// discard force-removes the half-initialised container; the removal error
	// is ignored because the original failure is the one reported.
	discard := func() {
		_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
	}

	if err := m.docker.ContainerStart(m.ctx, createResp.ID, container.StartOptions{}); err != nil {
		discard()
		return nil, fmt.Errorf("start ls worker container: %w", err)
	}

	inspect, err := m.docker.ContainerInspect(m.ctx, createResp.ID)
	if err != nil {
		discard()
		return nil, fmt.Errorf("inspect ls worker container: %w", err)
	}

	address, err := workerAddressFromInspect(inspect, m.cfg.Network)
	if err != nil {
		discard()
		return nil, err
	}

	m.logger.Info("created ls worker",
		"account", shortAccountID(accountID),
		"container", containerName,
		"address", address,
		"proxy_hash", shortHash)

	return &workerHandle{
		Key:         buildWorkerKey(accountID, proxyHash),
		AccountID:   accountID,
		ProxyURL:    proxyURL,
		ProxyHash:   proxyHash,
		ContainerID: createResp.ID,
		Container:   containerName,
		Address:     address,
		AuthToken:   authToken,
		LastUsed:    time.Now(),
	}, nil
}

// workerAddressFromInspect returns "ip:controlPort" for the container. It
// prefers the IP on networkName and otherwise falls back to any attached
// network that has an IP address (map iteration order, so unspecified which).
func workerAddressFromInspect(inspect container.InspectResponse, networkName string) (string, error) {
	if inspect.NetworkSettings == nil {
		return "", fmt.Errorf("ls worker inspect missing network settings")
	}

	port := strconv.Itoa(lsWorkerControlPort)

	if endpoint, ok := inspect.NetworkSettings.Networks[networkName]; ok && endpoint != nil {
		if ip := strings.TrimSpace(endpoint.IPAddress); ip != "" {
			return net.JoinHostPort(ip, port), nil
		}
	}
	for _, endpoint := range inspect.NetworkSettings.Networks {
		if endpoint == nil {
			continue
		}
		if ip := strings.TrimSpace(endpoint.IPAddress); ip != "" {
			return net.JoinHostPort(ip, port), nil
		}
	}
	return "", fmt.Errorf("ls worker missing IP address on network %s", networkName)
}

// firstContainerName returns names[0], or "" when names is empty.
func firstContainerName(names []string) string {
	if len(names) == 0 {
		return ""
	}
	return names[0]
}

// waitForWorkerHealthy polls the worker's /healthz endpoint every 500ms until
// it answers 200 OK or StartupTimeout elapses.
func (m *workerManager) waitForWorkerHealthy(handle *workerHandle) error {
	ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
	defer cancel()

	for {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/healthz", nil), nil)
		if err != nil {
			return err
		}
		req.Header.Set("X-Worker-Token", handle.AuthToken)

		// Transport errors are expected while the worker is still starting;
		// they are retried until the deadline.
		resp, err := m.http.Do(req)
		if err == nil {
			_ = resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				return nil
			}
		}

		select {
		case <-ctx.Done():
			return fmt.Errorf("worker %s failed health check: %w", handle.Container, ctx.Err())
		case <-time.After(500 * time.Millisecond):
		}
	}
}

// waitForWorkerReady polls the worker's /readyz endpoint (passing routingKey
// as the routing_key query parameter when non-blank) every 500ms until it
// answers 200 OK or StartupTimeout elapses. A 503 whose body contains the
// errLSModelMapDenied text aborts immediately with an error wrapping
// errLSModelMapDenied. The timeout error includes the last status and body
// seen, if any.
func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey string) error {
	ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
	defer cancel()

	values := url.Values{}
	if strings.TrimSpace(routingKey) != "" {
		values.Set("routing_key", routingKey)
	}

	var (
		lastStatus int
		lastBody   string
	)
	for {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/readyz", values), nil)
		if err != nil {
			return err
		}
		req.Header.Set("X-Worker-Token", handle.AuthToken)

		resp, err := m.http.Do(req)
		if err == nil {
			body, _ := io.ReadAll(resp.Body) // a partial body is still useful for diagnostics
			_ = resp.Body.Close()
			lastStatus = resp.StatusCode
			lastBody = truncate(string(body), 200)

			if resp.StatusCode == http.StatusOK {
				return nil
			}
			if isWorkerModelMappingUnavailable(resp.StatusCode, lastBody) {
				return fmt.Errorf("%w: worker %s %s", errLSModelMapDenied, handle.Container, strings.TrimSpace(lastBody))
			}
			if len(body) > 0 && shouldWarnWorkerNotReady(resp.StatusCode, lastBody) {
				m.logger.Warn("ls worker not ready yet",
					"container", handle.Container,
					"status", resp.StatusCode,
					"body", lastBody)
			}
		}

		select {
		case <-ctx.Done():
			if lastStatus > 0 || lastBody != "" {
				return fmt.Errorf("worker %s not ready for routing key %q (last_status=%d last_body=%q): %w", handle.Container, routingKey, lastStatus, lastBody, ctx.Err())
			}
			return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, ctx.Err())
		case <-time.After(500 * time.Millisecond):
		}
	}
}

// shouldWarnWorkerNotReady reports whether a not-ready response deserves a
// warning log. Only a 503 whose body contains "model mapping not ready"
// (case-insensitive) is suppressed.
func shouldWarnWorkerNotReady(status int, body string) bool {
	if status != http.StatusServiceUnavailable {
		return true
	}
	normalized := strings.ToLower(strings.TrimSpace(body))
	return !strings.Contains(normalized, "model mapping not ready")
}

// isWorkerModelMappingUnavailable reports whether the response is a 503 whose
// lower-cased body contains the errLSModelMapDenied message.
func isWorkerModelMappingUnavailable(status int, body string) bool {
	if status != http.StatusServiceUnavailable {
		return false
	}
	normalized := strings.ToLower(strings.TrimSpace(body))
	return strings.Contains(normalized, errLSModelMapDenied.Error())
}

// syncWorkerState posts state as JSON to the worker's /account/state endpoint
// unless the worker already accepted an identical payload (compared by
// SHA-256). It returns an error when state is nil, the request fails, or the
// worker answers with a non-200 status.
//
// It acquires m.mu briefly and therefore must not be called with m.mu held.
func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error {
	if state == nil {
		return fmt.Errorf("ls worker state is nil")
	}

	body, err := json.Marshal(state)
	if err != nil {
		return fmt.Errorf("marshal worker state: %w", err)
	}
	sum := fmt.Sprintf("%x", sha256.Sum256(body))

	// The handle is shared by concurrent GetOrCreate callers, so its
	// LastStateSHA is only touched under m.mu.
	m.mu.Lock()
	unchanged := handle.LastStateSHA == sum
	m.mu.Unlock()
	if unchanged {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, workerURL(handle, "/account/state", nil), bytes.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-Worker-Token", handle.AuthToken)

	resp, err := m.http.Do(req)
	if err != nil {
		return fmt.Errorf("sync worker state: %w", err)
	}
	defer resp.Body.Close()

	respBody, _ := io.ReadAll(resp.Body) // body is only used for the error message
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("sync worker state HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200))
	}

	m.mu.Lock()
	handle.LastStateSHA = sum
	m.mu.Unlock()
	return nil
}

// workerURL builds "http://<handle.Address><path>[?values]".
func workerURL(handle *workerHandle, path string, values url.Values) string {
	base := url.URL{
		Scheme: "http",
		Host:   handle.Address,
		Path:   path,
	}
	if values != nil {
		base.RawQuery = values.Encode()
	}
	return base.String()
}

// removeWorkerContainer stops (5 second grace period) and force-removes the
// container behind handle. A nil handle or blank container ID is a no-op.
func (m *workerManager) removeWorkerContainer(ctx context.Context, handle *workerHandle) error {
	if handle == nil || strings.TrimSpace(handle.ContainerID) == "" {
		return nil
	}

	timeout := 5
	// A failed stop is ignored: the forced removal below is attempted regardless.
	_ = m.docker.ContainerStop(ctx, handle.ContainerID, container.StopOptions{Timeout: &timeout})
	if err := m.docker.ContainerRemove(ctx, handle.ContainerID, container.RemoveOptions{Force: true}); err != nil {
		return fmt.Errorf("remove ls worker container %s: %w", handle.ContainerID, err)
	}
	return nil
}

// ensureStateLocked returns the state entry for accountID, creating an empty
// one if needed. The caller must hold m.mu.
func (m *workerManager) ensureStateLocked(accountID string) *workerAccountState {
	state := m.state[accountID]
	if state == nil {
		state = &workerAccountState{}
		m.state[accountID] = state
	}
	return state
}

// resolveWorkerProxy resolves and parses proxyURL. It returns ("", nil, nil)
// when parsing yields no proxy, and an error when the proxy fails to parse or
// its scheme is not socks5/socks5h.
func resolveWorkerProxy(proxyURL string) (string, *url.URL, error) {
	resolved := resolveLSProxy(proxyURL)
	normalized, parsed, err := proxyurl.Parse(resolved)
	if err != nil {
		return "", nil, err
	}
	if parsed == nil {
		return "", nil, nil
	}

	switch strings.ToLower(parsed.Scheme) {
	case "socks5", "socks5h":
		return normalized, parsed, nil
	default:
		return "", nil, fmt.Errorf("ls worker only supports socks5/socks5h proxies, got %s", parsed.Scheme)
	}
}

// proxyHash returns the hex SHA-256 of the trimmed proxy URL, or "direct" when
// the URL is blank.
func proxyHash(proxyURL string) string {
	trimmed := strings.TrimSpace(proxyURL)
	if trimmed == "" {
		return "direct"
	}
	sum := sha256.Sum256([]byte(trimmed))
	return fmt.Sprintf("%x", sum[:])
}

// buildWorkerKey returns the map key "<accountID>:<proxyHash>".
func buildWorkerKey(accountID, proxyHash string) string {
	return accountID + ":" + proxyHash
}

// cloneInt32Ptr returns a pointer to a copy of *v, or nil when v is nil.
func cloneInt32Ptr(v *int32) *int32 {
	if v == nil {
		return nil
	}
	cp := *v
	return &cp
}

// cloneWorkerAccountState returns a deep copy of state (pointer fields are
// duplicated), or nil when state is nil.
func cloneWorkerAccountState(state *workerAccountState) *workerAccountState {
	if state == nil {
		return nil
	}
	cp := *state
	cp.AvailableCredits = cloneInt32Ptr(state.AvailableCredits)
	cp.MinimumCreditAmount = cloneInt32Ptr(state.MinimumCreditAmount)
	if state.ExpiresAt != nil {
		ts := *state.ExpiresAt
		cp.ExpiresAt = &ts
	}
	return &cp
}