sub2api/backend/internal/pkg/lspool/worker_server.go
win b856586412
Some checks failed
CI / test (push) Failing after 16m30s
CI / golangci-lint (push) Failing after 4s
Security Scan / backend-security (push) Failing after 1m35s
Security Scan / frontend-security (push) Failing after 1m31s
修复h1
2026-04-01 01:35:49 +08:00

375 lines
9.5 KiB
Go

package lspool
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"
)
// WorkerServerConfig holds the settings for a single-account LS worker
// control server.
//
// AccountID and AuthToken are required by NewWorkerServer. NewWorkerServer
// fills in defaults for an empty ListenAddr or AppRoot and for a non-positive
// MaxIdleTime or HealthInterval. An empty NetworkReadyFile disables the
// network-readiness file check.
type WorkerServerConfig struct {
	// Identity of the served account and the shared secret expected in the
	// X-Worker-Token request header.
	AccountID, AuthToken string
	// Control-API listen address, LS application root, and an optional file
	// whose existence signals that worker networking is ready.
	ListenAddr, AppRoot, NetworkReadyFile string
	// Pool tuning forwarded to the LS pool configuration.
	MaxIdleTime, HealthInterval time.Duration
}
// WorkerServer is the authenticated HTTP control API for one account's LS
// pool. It also keeps the most recent account state received on
// /account/state.
type WorkerServer struct {
	cfg    WorkerServerConfig
	pool   *Pool
	logger *slog.Logger

	// mu guards state. Readers snapshot it with cloneWorkerAccountState while
	// holding the read lock, then work on the copy.
	mu    sync.RWMutex
	state workerAccountState
}
// NewWorkerServer validates cfg, applies defaults, and builds a WorkerServer
// backed by a new Pool.
//
// All string fields are stored with surrounding whitespace removed.
// AccountID and AuthToken must be non-blank; otherwise an error is returned.
// Defaults:
//   - ListenAddr: "0.0.0.0:<lsWorkerControlPort>" when empty
//   - AppRoot: "/app/ls" when empty
//   - MaxIdleTime: 15m when non-positive
//   - HealthInterval: 30s when non-positive
func NewWorkerServer(cfg WorkerServerConfig) (*WorkerServer, error) {
	// Normalize once so validation and later use see identical values. This
	// matters for AuthToken: net/http trims surrounding whitespace from header
	// values, so a padded token could never match X-Worker-Token.
	cfg.AccountID = strings.TrimSpace(cfg.AccountID)
	cfg.AuthToken = strings.TrimSpace(cfg.AuthToken)
	cfg.ListenAddr = strings.TrimSpace(cfg.ListenAddr)
	cfg.AppRoot = strings.TrimSpace(cfg.AppRoot)
	cfg.NetworkReadyFile = strings.TrimSpace(cfg.NetworkReadyFile)

	if cfg.AccountID == "" {
		return nil, fmt.Errorf("worker account id is required")
	}
	if cfg.AuthToken == "" {
		return nil, fmt.Errorf("worker auth token is required")
	}
	if cfg.ListenAddr == "" {
		cfg.ListenAddr = fmt.Sprintf("0.0.0.0:%d", lsWorkerControlPort)
	}
	if cfg.AppRoot == "" {
		cfg.AppRoot = "/app/ls"
	}
	if cfg.MaxIdleTime <= 0 {
		cfg.MaxIdleTime = 15 * time.Minute
	}
	if cfg.HealthInterval <= 0 {
		cfg.HealthInterval = 30 * time.Second
	}

	poolCfg := DefaultConfig()
	poolCfg.AppRoot = cfg.AppRoot
	poolCfg.MaxIdleTime = cfg.MaxIdleTime
	poolCfg.HealthCheckInterval = cfg.HealthInterval

	return &WorkerServer{
		cfg:    cfg,
		pool:   NewPool(poolCfg),
		logger: slog.Default().With("component", "lsworker"),
	}, nil
}
// NewWorkerServerFromEnv builds a WorkerServer from environment variables.
//
// Variables read:
//   - LSWORKER_ACCOUNT_ID
//   - LSWORKER_AUTH_TOKEN
//   - LSWORKER_LISTEN_ADDR
//   - ANTIGRAVITY_APP_ROOT
//   - LSWORKER_NETWORK_READY_FILE
//   - LSWORKER_POOL_MAX_IDLE_TIME (default 15m)
//   - LSWORKER_POOL_HEALTH_INTERVAL (default 30s)
//
// The two pool durations use time.ParseDuration syntax. A value that fails to
// parse is logged as a warning and the default is used. Validation and the
// remaining defaults are handled by NewWorkerServer.
func NewWorkerServerFromEnv() (*WorkerServer, error) {
	envDuration := func(key string, fallback time.Duration) time.Duration {
		raw := strings.TrimSpace(os.Getenv(key))
		if raw == "" {
			return fallback
		}
		parsed, err := time.ParseDuration(raw)
		if err != nil {
			// Stay best-effort, but make a misconfiguration visible instead of
			// silently running with the default.
			slog.Warn("invalid duration in environment, using default",
				"component", "lsworker", "env", key, "value", raw,
				"default", fallback.String(), "error", err)
			return fallback
		}
		return parsed
	}

	return NewWorkerServer(WorkerServerConfig{
		AccountID:        strings.TrimSpace(os.Getenv("LSWORKER_ACCOUNT_ID")),
		AuthToken:        strings.TrimSpace(os.Getenv("LSWORKER_AUTH_TOKEN")),
		ListenAddr:       strings.TrimSpace(os.Getenv("LSWORKER_LISTEN_ADDR")),
		AppRoot:          strings.TrimSpace(os.Getenv("ANTIGRAVITY_APP_ROOT")),
		NetworkReadyFile: strings.TrimSpace(os.Getenv("LSWORKER_NETWORK_READY_FILE")),
		MaxIdleTime:      envDuration("LSWORKER_POOL_MAX_IDLE_TIME", 15*time.Minute),
		HealthInterval:   envDuration("LSWORKER_POOL_HEALTH_INTERVAL", 30*time.Second),
	})
}
// Close shuts down the server's LS pool. It does nothing when no pool is set.
func (s *WorkerServer) Close() {
	if s.pool == nil {
		return
	}
	s.pool.Close()
}
// Handler returns the routing table of the worker control API. Each route's
// handler performs the X-Worker-Token check itself before doing any work.
func (s *WorkerServer) Handler() http.Handler {
	routes := []struct {
		pattern string
		handle  func(http.ResponseWriter, *http.Request)
	}{
		{"/healthz", s.handleHealthz},
		{"/readyz", s.handleReadyz},
		{"/account/state", s.handleAccountState},
		{"/rpc/unary", s.handleRPCUnary},
		{"/rpc/stream", s.handleRPCStream},
	}

	mux := http.NewServeMux()
	for _, route := range routes {
		mux.HandleFunc(route.pattern, route.handle)
	}
	return mux
}
// handleHealthz is a liveness endpoint. After authorization it always answers
// 200 "ok" without consulting the LS pool.
func (s *WorkerServer) handleHealthz(w http.ResponseWriter, r *http.Request) {
	if authorized := s.authorize(w, r); !authorized {
		return
	}
	w.WriteHeader(http.StatusOK)
	_, _ = io.WriteString(w, "ok")
}
// handleReadyz is a readiness endpoint. It runs ensureReady for the optional
// routing_key query parameter. It answers 200 "ready replica=<n>" on success,
// or 503 with the readiness error text.
func (s *WorkerServer) handleReadyz(w http.ResponseWriter, r *http.Request) {
	if !s.authorize(w, r) {
		return
	}
	inst, err := s.ensureReady(r.Context(), strings.TrimSpace(r.URL.Query().Get("routing_key")))
	if err != nil {
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}
	w.WriteHeader(http.StatusOK)
	_, _ = fmt.Fprintf(w, "ready replica=%d", inst.Replica)
}
// handleAccountState replaces the stored account state with the JSON payload
// POSTed by the caller, then immediately pushes it into the pool via
// applyState.
//
// Responses:
//   - 401: missing or invalid X-Worker-Token
//   - 405: method other than POST
//   - 400: body does not decode as a workerAccountState
//   - 200 "ok": state stored and applied
func (s *WorkerServer) handleAccountState(w http.ResponseWriter, r *http.Request) {
	switch {
	case !s.authorize(w, r):
		return
	case r.Method != http.MethodPost:
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	defer r.Body.Close()

	var incoming workerAccountState
	if decodeErr := json.NewDecoder(r.Body).Decode(&incoming); decodeErr != nil {
		http.Error(w, "invalid account state payload", http.StatusBadRequest)
		return
	}

	s.mu.Lock()
	s.state = *cloneWorkerAccountState(&incoming)
	s.mu.Unlock()

	s.applyState()
	w.WriteHeader(http.StatusOK)
	_, _ = io.WriteString(w, "ok")
}
// handleRPCUnary serves POST /rpc/unary. It forwards one request body to the
// LS instance selected by routing_key and writes the single response back.
//
// Modes:
//   - "json": the body is decoded as JSON (an empty body means "{}") and sent
//     with CallUnaryJSON.
//   - "proto": the raw bytes are sent unchanged with CallRPC. An empty body
//     stays empty.
//
// Status codes:
//   - 405 / 400: bad method, target, mode, or body, rejected before any LS
//     instance is acquired
//   - 503: readiness failure
//   - 502: upstream RPC failure
func (s *WorkerServer) handleRPCUnary(w http.ResponseWriter, r *http.Request) {
	if !s.authorize(w, r) {
		return
	}
	service, method, mode, routingKey, ok := parseRPCRequest(w, r)
	if !ok {
		return
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "read request body failed", http.StatusBadRequest)
		return
	}

	// Validate the payload before ensureReady, which may have to create an LS
	// instance; a malformed request should not pay for that.
	var jsonInput any
	switch mode {
	case "json":
		if len(body) == 0 {
			body = []byte("{}")
		}
		if err := json.Unmarshal(body, &jsonInput); err != nil {
			http.Error(w, "invalid json rpc body", http.StatusBadRequest)
			return
		}
	case "proto":
		// Forward the bytes untouched. Substituting "{}" here would send two
		// bytes that are not a valid protobuf encoding. An empty body is the
		// encoding of a message with every field unset.
	default:
		http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
		return
	}

	inst, err := s.ensureReady(r.Context(), routingKey)
	if err != nil {
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}

	var respBody []byte
	if mode == "json" {
		respBody, err = inst.CallUnaryJSON(r.Context(), service, method, jsonInput)
	} else {
		respBody, err = inst.CallRPC(r.Context(), service, method, body)
	}
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write(respBody)
}
// handleRPCStream serves POST /rpc/stream. It opens a streaming RPC on the LS
// instance selected by routing_key and relays the upstream HTTP response
// (status, headers, body) to the client as the data arrives.
//
// Modes:
//   - "json": the body is decoded as JSON (an empty body means "{}") and sent
//     with StreamRPCJSON.
//   - "proto": the raw bytes are sent unchanged with StreamRPC.
//
// Status codes:
//   - 405 / 400: bad method, target, mode, or body, rejected before any LS
//     instance is acquired
//   - 503: readiness failure
//   - 502: the stream could not be opened
//
// Once the upstream status has been written, later copy errors cannot be
// reported to the client; the response simply ends.
func (s *WorkerServer) handleRPCStream(w http.ResponseWriter, r *http.Request) {
	if !s.authorize(w, r) {
		return
	}
	service, method, mode, routingKey, ok := parseRPCRequest(w, r)
	if !ok {
		return
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "read request body failed", http.StatusBadRequest)
		return
	}

	// Validate the payload before ensureReady, which may have to create an LS
	// instance; a malformed request should not pay for that.
	var jsonInput any
	switch mode {
	case "json":
		if len(body) == 0 {
			body = []byte("{}")
		}
		if err := json.Unmarshal(body, &jsonInput); err != nil {
			http.Error(w, "invalid json rpc body", http.StatusBadRequest)
			return
		}
	case "proto":
		// Raw bytes are forwarded untouched.
	default:
		http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
		return
	}

	inst, err := s.ensureReady(r.Context(), routingKey)
	if err != nil {
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}

	var resp *http.Response
	if mode == "json" {
		resp, err = inst.StreamRPCJSON(r.Context(), service, method, jsonInput)
	} else {
		resp, err = inst.StreamRPC(r.Context(), service, method, body)
	}
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	header := w.Header()
	for key, values := range resp.Header {
		for _, value := range values {
			header.Add(key, value)
		}
	}
	w.WriteHeader(resp.StatusCode)

	// The ResponseWriter buffers output, so a plain io.Copy would hold
	// streamed chunks back until the buffer fills or the handler returns.
	// Flush after the headers and after every chunk so clients see events as
	// they are produced. Writers that cannot flush just keep buffering.
	rc := http.NewResponseController(w)
	_ = rc.Flush()
	buf := make([]byte, 32*1024)
	for {
		n, readErr := resp.Body.Read(buf)
		if n > 0 {
			if _, writeErr := w.Write(buf[:n]); writeErr != nil {
				return // client went away; nothing left to deliver
			}
			_ = rc.Flush()
		}
		if readErr != nil {
			// io.EOF is normal completion; any other error truncates the
			// stream, and the status line has already been sent.
			return
		}
	}
}
// authorize compares the X-Worker-Token request header with the configured
// auth token. On mismatch it writes 401 "unauthorized" and returns false; the
// caller must then stop handling the request.
func (s *WorkerServer) authorize(w http.ResponseWriter, r *http.Request) bool {
	presented := r.Header.Get("X-Worker-Token")
	if !subtleHeaderEqual(presented, s.cfg.AuthToken) {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return false
	}
	return true
}
// subtleHeaderEqual reports whether a presented header value equals the
// expected secret. Empty values never match, so an unset token cannot
// authorize anything.
//
// The comparison uses crypto/subtle, so its running time does not depend on
// how many leading bytes match. A plain == would let a remote caller recover
// the token byte by byte from response timing. Only the length can still
// influence timing.
func subtleHeaderEqual(left, right string) bool {
	if left == "" || right == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(left), []byte(right)) == 1
}
// parseRPCRequest validates the parameters shared by the /rpc/* endpoints.
//
// Rules:
//   - The method must be POST.
//   - The "service" and "method" query parameters must be non-blank.
//   - "mode" is case-insensitive, defaults to "proto", and must be "proto" or
//     "json".
//   - "routing_key" is optional.
//
// All values are returned trimmed. On failure it writes a 405 or 400 response
// to w and returns ok == false; the caller must then return without writing
// anything else.
func parseRPCRequest(w http.ResponseWriter, r *http.Request) (service, method, mode, routingKey string, ok bool) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return "", "", "", "", false
	}
	query := r.URL.Query()
	service = strings.TrimSpace(query.Get("service"))
	method = strings.TrimSpace(query.Get("method"))
	mode = strings.ToLower(strings.TrimSpace(query.Get("mode")))
	routingKey = strings.TrimSpace(query.Get("routing_key"))
	if service == "" || method == "" {
		http.Error(w, "missing rpc target", http.StatusBadRequest)
		return "", "", "", "", false
	}
	switch mode {
	case "":
		mode = "proto"
	case "proto", "json":
	default:
		// Reject here so handlers never acquire an LS instance for a request
		// they would refuse anyway.
		http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
		return "", "", "", "", false
	}
	return service, method, mode, routingKey, true
}
// ensureReady returns a pooled LS instance for routingKey that is ready to
// serve RPCs, or an error explaining why the worker cannot serve yet.
//
// Checks, in order:
//   - NetworkReadyFile (when configured) must exist on disk.
//   - The latest account state is pushed to the pool (applyState) and must
//     carry a non-blank access token.
//   - The instance must not be marked as having its model mapping
//     unavailable; that case is reported wrapping errLSModelMapDenied.
//   - Its model mapping must be ready. If it is not, RefreshModelMapping is
//     attempted once.
//
// ctx is checked for cancellation before the refresh.
// NOTE(review): RefreshModelMapping takes no context, so the refresh itself is
// not bounded by ctx or by any timeout.
func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Instance, error) {
	if path := strings.TrimSpace(s.cfg.NetworkReadyFile); path != "" {
		if _, err := os.Stat(path); err != nil {
			return nil, fmt.Errorf("worker network not ready: %w", err)
		}
	}

	s.applyState()
	s.mu.RLock()
	state := cloneWorkerAccountState(&s.state)
	s.mu.RUnlock()
	if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
		return nil, fmt.Errorf("worker access token not configured")
	}

	inst, err := s.pool.GetOrCreate(s.cfg.AccountID, routingKey, "")
	if err != nil {
		return nil, err
	}
	mappingDenied := func() error {
		return fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
	}
	if inst.HasModelMappingUnavailable() {
		return nil, mappingDenied()
	}
	if inst.HasModelMappingReady() {
		return inst, nil
	}

	// Don't start a refresh on behalf of a request that is already gone.
	if err := ctx.Err(); err != nil {
		return nil, fmt.Errorf("worker model mapping refresh for replica %d: %w", inst.Replica, err)
	}
	if !RefreshModelMapping(inst) {
		if inst.HasModelMappingUnavailable() {
			return nil, mappingDenied()
		}
		return nil, fmt.Errorf("worker model mapping not ready for replica %d", inst.Replica)
	}
	return inst, nil
}
// applyState pushes a snapshot of the stored account state into the pool:
//   - When HasToken is set, it pushes the access and refresh tokens, with
//     ExpiresAt converted to UTC (zero time when absent).
//   - When HasModelCredits is set, it pushes the model-credit settings.
//
// It does nothing when no snapshot is available.
func (s *WorkerServer) applyState() {
	s.mu.RLock()
	snapshot := cloneWorkerAccountState(&s.state)
	s.mu.RUnlock()
	if snapshot == nil {
		return
	}

	if snapshot.HasToken {
		var expiry time.Time
		if snapshot.ExpiresAt != nil {
			expiry = snapshot.ExpiresAt.UTC()
		}
		s.pool.SetAccountToken(s.cfg.AccountID, snapshot.AccessToken, snapshot.RefreshToken, expiry)
	}

	if !snapshot.HasModelCredits {
		return
	}
	s.pool.SetAccountModelCredits(s.cfg.AccountID, snapshot.UseAICredits, snapshot.AvailableCredits, snapshot.MinimumCreditAmount)
}
// workerHTTPServer builds the http.Server for the worker control API.
//
// Only ReadHeaderTimeout (10s) is set. Read, write, and idle timeouts stay at
// zero, presumably so long-lived /rpc/stream responses are not cut off.
func workerHTTPServer(listenAddr string, handler http.Handler) *http.Server {
	const readHeaderTimeout = 10 * time.Second
	return &http.Server{
		Handler:           handler,
		Addr:              listenAddr,
		ReadHeaderTimeout: readHeaderTimeout,
	}
}
// workerExitCode maps the worker's final error to a process exit status:
// 0 when err is nil, 1 otherwise.
func workerExitCode(err error) int {
	code := 0
	if err != nil {
		code = 1
	}
	return code
}
// parseWorkerControlPort returns the control port from LSWORKER_CONTROL_PORT
// (surrounding whitespace ignored). It falls back to lsWorkerControlPort when
// the variable is unset, is not a decimal integer, or lies outside the valid
// TCP port range 1-65535.
func parseWorkerControlPort() int {
	raw := strings.TrimSpace(os.Getenv("LSWORKER_CONTROL_PORT"))
	if raw == "" {
		return lsWorkerControlPort
	}
	port, err := strconv.Atoi(raw)
	if err != nil || port < 1 || port > 65535 {
		return lsWorkerControlPort
	}
	return port
}