375 lines
9.5 KiB
Go
375 lines
9.5 KiB
Go
package lspool
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"
)
|
|
|
|
// WorkerServerConfig carries the settings used to build a WorkerServer.
type WorkerServerConfig struct {
	// AccountID names the single account this worker serves. Required.
	AccountID string
	// AuthToken is the shared secret callers present in the X-Worker-Token
	// header. Required.
	AuthToken string

	// ListenAddr is the control listener address. When blank it defaults to
	// 0.0.0.0 on lsWorkerControlPort.
	ListenAddr string
	// AppRoot is forwarded to the pool configuration. When blank it defaults
	// to "/app/ls".
	AppRoot string
	// NetworkReadyFile, when non-empty, is a path that must exist (os.Stat
	// succeeds) before the worker reports itself ready.
	NetworkReadyFile string

	// MaxIdleTime is forwarded to the pool as its MaxIdleTime. Non-positive
	// values are replaced with 15 minutes.
	MaxIdleTime time.Duration
	// HealthInterval is forwarded to the pool as its HealthCheckInterval.
	// Non-positive values are replaced with 30 seconds.
	HealthInterval time.Duration
}
|
|
|
|
type WorkerServer struct {
|
|
cfg WorkerServerConfig
|
|
pool *Pool
|
|
logger *slog.Logger
|
|
|
|
mu sync.RWMutex
|
|
state workerAccountState
|
|
}
|
|
|
|
func NewWorkerServer(cfg WorkerServerConfig) (*WorkerServer, error) {
|
|
if strings.TrimSpace(cfg.AccountID) == "" {
|
|
return nil, fmt.Errorf("worker account id is required")
|
|
}
|
|
if strings.TrimSpace(cfg.AuthToken) == "" {
|
|
return nil, fmt.Errorf("worker auth token is required")
|
|
}
|
|
if strings.TrimSpace(cfg.ListenAddr) == "" {
|
|
cfg.ListenAddr = fmt.Sprintf("0.0.0.0:%d", lsWorkerControlPort)
|
|
}
|
|
if strings.TrimSpace(cfg.AppRoot) == "" {
|
|
cfg.AppRoot = "/app/ls"
|
|
}
|
|
if cfg.MaxIdleTime <= 0 {
|
|
cfg.MaxIdleTime = 15 * time.Minute
|
|
}
|
|
if cfg.HealthInterval <= 0 {
|
|
cfg.HealthInterval = 30 * time.Second
|
|
}
|
|
|
|
poolCfg := DefaultConfig()
|
|
poolCfg.AppRoot = cfg.AppRoot
|
|
poolCfg.MaxIdleTime = cfg.MaxIdleTime
|
|
poolCfg.HealthCheckInterval = cfg.HealthInterval
|
|
|
|
return &WorkerServer{
|
|
cfg: cfg,
|
|
pool: NewPool(poolCfg),
|
|
logger: slog.Default().With("component", "lsworker"),
|
|
}, nil
|
|
}
|
|
|
|
func NewWorkerServerFromEnv() (*WorkerServer, error) {
|
|
maxIdleTime := 15 * time.Minute
|
|
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_MAX_IDLE_TIME")); raw != "" {
|
|
if parsed, err := time.ParseDuration(raw); err == nil {
|
|
maxIdleTime = parsed
|
|
}
|
|
}
|
|
healthInterval := 30 * time.Second
|
|
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_HEALTH_INTERVAL")); raw != "" {
|
|
if parsed, err := time.ParseDuration(raw); err == nil {
|
|
healthInterval = parsed
|
|
}
|
|
}
|
|
|
|
return NewWorkerServer(WorkerServerConfig{
|
|
AccountID: strings.TrimSpace(os.Getenv("LSWORKER_ACCOUNT_ID")),
|
|
AuthToken: strings.TrimSpace(os.Getenv("LSWORKER_AUTH_TOKEN")),
|
|
ListenAddr: strings.TrimSpace(os.Getenv("LSWORKER_LISTEN_ADDR")),
|
|
AppRoot: strings.TrimSpace(os.Getenv("ANTIGRAVITY_APP_ROOT")),
|
|
NetworkReadyFile: strings.TrimSpace(os.Getenv("LSWORKER_NETWORK_READY_FILE")),
|
|
MaxIdleTime: maxIdleTime,
|
|
HealthInterval: healthInterval,
|
|
})
|
|
}
|
|
|
|
func (s *WorkerServer) Close() {
|
|
if s.pool != nil {
|
|
s.pool.Close()
|
|
}
|
|
}
|
|
|
|
func (s *WorkerServer) Handler() http.Handler {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/healthz", s.handleHealthz)
|
|
mux.HandleFunc("/readyz", s.handleReadyz)
|
|
mux.HandleFunc("/account/state", s.handleAccountState)
|
|
mux.HandleFunc("/rpc/unary", s.handleRPCUnary)
|
|
mux.HandleFunc("/rpc/stream", s.handleRPCStream)
|
|
return mux
|
|
}
|
|
|
|
func (s *WorkerServer) handleHealthz(w http.ResponseWriter, r *http.Request) {
|
|
if !s.authorize(w, r) {
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("ok"))
|
|
}
|
|
|
|
func (s *WorkerServer) handleReadyz(w http.ResponseWriter, r *http.Request) {
|
|
if !s.authorize(w, r) {
|
|
return
|
|
}
|
|
routingKey := strings.TrimSpace(r.URL.Query().Get("routing_key"))
|
|
inst, err := s.ensureReady(r.Context(), routingKey)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(fmt.Sprintf("ready replica=%d", inst.Replica)))
|
|
}
|
|
|
|
func (s *WorkerServer) handleAccountState(w http.ResponseWriter, r *http.Request) {
|
|
if !s.authorize(w, r) {
|
|
return
|
|
}
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
defer r.Body.Close()
|
|
var payload workerAccountState
|
|
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
|
http.Error(w, "invalid account state payload", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
s.mu.Lock()
|
|
s.state = *cloneWorkerAccountState(&payload)
|
|
s.mu.Unlock()
|
|
s.applyState()
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("ok"))
|
|
}
|
|
|
|
func (s *WorkerServer) handleRPCUnary(w http.ResponseWriter, r *http.Request) {
|
|
if !s.authorize(w, r) {
|
|
return
|
|
}
|
|
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
inst, err := s.ensureReady(r.Context(), routingKey)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
http.Error(w, "read request body failed", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if len(body) == 0 {
|
|
body = []byte("{}")
|
|
}
|
|
|
|
var respBody []byte
|
|
switch mode {
|
|
case "json":
|
|
var input any
|
|
if err := json.Unmarshal(body, &input); err != nil {
|
|
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
respBody, err = inst.CallUnaryJSON(r.Context(), service, method, input)
|
|
case "proto":
|
|
respBody, err = inst.CallRPC(r.Context(), service, method, body)
|
|
default:
|
|
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write(respBody)
|
|
}
|
|
|
|
func (s *WorkerServer) handleRPCStream(w http.ResponseWriter, r *http.Request) {
|
|
if !s.authorize(w, r) {
|
|
return
|
|
}
|
|
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
inst, err := s.ensureReady(r.Context(), routingKey)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
http.Error(w, "read request body failed", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var resp *http.Response
|
|
switch mode {
|
|
case "json":
|
|
var input any
|
|
if len(body) == 0 {
|
|
body = []byte("{}")
|
|
}
|
|
if err := json.Unmarshal(body, &input); err != nil {
|
|
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
resp, err = inst.StreamRPCJSON(r.Context(), service, method, input)
|
|
case "proto":
|
|
resp, err = inst.StreamRPC(r.Context(), service, method, body)
|
|
default:
|
|
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
for key, values := range resp.Header {
|
|
for _, value := range values {
|
|
w.Header().Add(key, value)
|
|
}
|
|
}
|
|
w.WriteHeader(resp.StatusCode)
|
|
_, _ = io.Copy(w, resp.Body)
|
|
}
|
|
|
|
func (s *WorkerServer) authorize(w http.ResponseWriter, r *http.Request) bool {
|
|
if subtleHeaderEqual(r.Header.Get("X-Worker-Token"), s.cfg.AuthToken) {
|
|
return true
|
|
}
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return false
|
|
}
|
|
|
|
// subtleHeaderEqual reports whether left and right are equal, non-empty
// strings. An empty value on either side never matches, so a missing header
// or an unset secret cannot authenticate.
//
// The comparison uses crypto/subtle so its duration does not depend on how
// many leading bytes match, which a plain == on a secret would leak. Only a
// length mismatch is still observable through timing.
func subtleHeaderEqual(left, right string) bool {
	if left == "" || right == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(left), []byte(right)) == 1
}
|
|
|
|
// parseRPCRequest extracts the RPC target from a POST request's query string.
//
// It reads service, method, mode and routing_key, trimming whitespace from
// each; mode is lower-cased and defaults to "proto" when absent. On a
// non-POST request it writes 405, and when service or method is blank it
// writes 400; in both cases it returns empty strings with ok == false and the
// caller must not write a further response.
func parseRPCRequest(w http.ResponseWriter, r *http.Request) (service, method, mode, routingKey string, ok bool) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return "", "", "", "", false
	}

	params := r.URL.Query()
	param := func(name string) string {
		return strings.TrimSpace(params.Get(name))
	}

	svc, meth := param("service"), param("method")
	if svc == "" || meth == "" {
		http.Error(w, "missing rpc target", http.StatusBadRequest)
		return "", "", "", "", false
	}

	rpcMode := strings.ToLower(param("mode"))
	if rpcMode == "" {
		rpcMode = "proto"
	}
	return svc, meth, rpcMode, param("routing_key"), true
}
|
|
|
|
func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Instance, error) {
|
|
if path := strings.TrimSpace(s.cfg.NetworkReadyFile); path != "" {
|
|
if _, err := os.Stat(path); err != nil {
|
|
return nil, fmt.Errorf("worker network not ready: %w", err)
|
|
}
|
|
}
|
|
|
|
s.applyState()
|
|
s.mu.RLock()
|
|
state := cloneWorkerAccountState(&s.state)
|
|
s.mu.RUnlock()
|
|
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
|
|
return nil, fmt.Errorf("worker access token not configured")
|
|
}
|
|
|
|
inst, err := s.pool.GetOrCreate(s.cfg.AccountID, routingKey, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if inst.HasModelMappingUnavailable() {
|
|
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
|
|
}
|
|
if inst.HasModelMappingReady() {
|
|
return inst, nil
|
|
}
|
|
|
|
modelCtx, cancel := context.WithTimeout(ctx, lsModelConfigTimeout)
|
|
defer cancel()
|
|
_ = modelCtx
|
|
if !RefreshModelMapping(inst) {
|
|
if inst.HasModelMappingUnavailable() {
|
|
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
|
|
}
|
|
return nil, fmt.Errorf("worker model mapping not ready for replica %d", inst.Replica)
|
|
}
|
|
return inst, nil
|
|
}
|
|
|
|
func (s *WorkerServer) applyState() {
|
|
s.mu.RLock()
|
|
state := cloneWorkerAccountState(&s.state)
|
|
s.mu.RUnlock()
|
|
if state == nil {
|
|
return
|
|
}
|
|
if state.HasToken {
|
|
expiresAt := time.Time{}
|
|
if state.ExpiresAt != nil {
|
|
expiresAt = state.ExpiresAt.UTC()
|
|
}
|
|
s.pool.SetAccountToken(s.cfg.AccountID, state.AccessToken, state.RefreshToken, expiresAt)
|
|
}
|
|
if state.HasModelCredits {
|
|
s.pool.SetAccountModelCredits(s.cfg.AccountID, state.UseAICredits, state.AvailableCredits, state.MinimumCreditAmount)
|
|
}
|
|
}
|
|
|
|
// workerHTTPServer builds the http.Server for the worker's control surface,
// bound to listenAddr and serving handler. Only the header-read timeout is
// set (10 seconds); no other timeouts are configured.
func workerHTTPServer(listenAddr string, handler http.Handler) *http.Server {
	const headerTimeout = 10 * time.Second

	srv := new(http.Server)
	srv.Addr = listenAddr
	srv.Handler = handler
	srv.ReadHeaderTimeout = headerTimeout
	return srv
}
|
|
|
|
// workerExitCode maps a terminal error to a process exit status: 1 for any
// non-nil error, 0 otherwise.
func workerExitCode(err error) int {
	if err != nil {
		return 1
	}
	return 0
}
|
|
|
|
func parseWorkerControlPort() int {
|
|
raw := strings.TrimSpace(os.Getenv("LSWORKER_CONTROL_PORT"))
|
|
if raw == "" {
|
|
return lsWorkerControlPort
|
|
}
|
|
port, err := strconv.Atoi(raw)
|
|
if err != nil || port < 1 {
|
|
return lsWorkerControlPort
|
|
}
|
|
return port
|
|
}
|