diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index bd67f336..3853b251 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -237,8 +237,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) - langServerService := service.ProvideLanguageServerService(httpUpstream) - engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient, langServerService) + langServerService := service.ProvideLanguageServerService(httpUpstream, antigravityGatewayService, accountRepository) + lsrpcHandler := service.NewLSRPCHandler(antigravityGatewayService, accountRepository, nil) + engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient, langServerService, lsrpcHandler) httpServer := server.ProvideHTTPServer(configConfig, engine) opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig) opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig) diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index c743b514..22971cef 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -40,6 +40,7 @@ func ProvideRouter( settingService *service.SettingService, redisClient *redis.Client, langServerService *service.LanguageServerService, + lsrpcHandler *service.LSRPCHandler, ) *gin.Engine { if cfg.Server.Mode == "release" { gin.SetMode(gin.ReleaseMode) @@ -96,7 +97,7 @@ func ProvideRouter( service.SetWebSearchManager(websearch.NewManager(configs, redisClient)) }) - return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService) + return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService, lsrpcHandler) } // ProvideHTTPServer 提供 HTTP 服务器 diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index c76cb39e..786f7c04 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -7,6 +7,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pbconnect" "github.com/Wei-Shaw/sub2api/internal/handler" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/routes" @@ -33,6 +34,7 @@ func SetupRouter( cfg *config.Config, redisClient *redis.Client, langServerService *service.LanguageServerService, + lsrpcHandler *service.LSRPCHandler, ) *gin.Engine { // 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src var cachedFrameOrigins atomic.Pointer[[]string] @@ -82,7 +84,7 @@ func SetupRouter( } // 注册路由 - registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService) + registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService, lsrpcHandler) return r } @@ -101,6 +103,7 @@ func registerRoutes( cfg *config.Config, redisClient *redis.Client, langServerService *service.LanguageServerService, + lsrpcHandler *service.LSRPCHandler, ) { // 通用路由(健康检查、状态等) routes.RegisterCommonRoutes(r) @@ -117,5 +120,12 @@ func registerRoutes( // 注册 Antigravity HTTP API 路由 routes.RegisterAntigravityHTTPRoutes(v1, langServerService) + // 挂载 connectrpc LanguageServerService 路由 + // Claude Code 客户端通过 /exa.language_server_pb.LanguageServerService/* 路径访问 + if lsrpcHandler != nil { + lsrpcPath, lsrpcHTTPHandler := language_server_pbconnect.NewLanguageServerServiceHandler(lsrpcHandler) + r.Any(lsrpcPath+"*action", gin.WrapH(lsrpcHTTPHandler)) + } + routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService) } diff --git a/backend/internal/server/routes/antigravity_http_test.go b/backend/internal/server/routes/antigravity_http_test.go index 83a6a116..31e8f91e 100644 --- a/backend/internal/server/routes/antigravity_http_test.go +++ b/backend/internal/server/routes/antigravity_http_test.go @@ -18,7 +18,7 @@ func TestAntigravityHTTPRoutes(t *testing.T) { gin.SetMode(gin.TestMode) // 创建模拟的 LanguageServerService - mockService := service.NewLanguageServerService(slog.Default(), nil) + mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil) defer mockService.Stop() // 创建路由 @@ -143,7 +143,7 @@ func TestAntigravityHTTPRoutes(t *testing.T) { func TestStartCascadeValidation(t *testing.T) { gin.SetMode(gin.TestMode) - mockService := service.NewLanguageServerService(slog.Default(), nil) + mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil) defer mockService.Stop() r := gin.New() @@ -185,7 +185,7 @@ func TestStartCascadeValidation(t *testing.T) { func TestRateLimiting(t *testing.T) { gin.SetMode(gin.TestMode) - mockService := service.NewLanguageServerService(slog.Default(), nil) + mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil) defer mockService.Stop() r := gin.New() @@ -257,7 +257,7 @@ func TestRateLimiting(t *testing.T) { func TestSessionCleanup(t *testing.T) { gin.SetMode(gin.TestMode) - mockService := service.NewLanguageServerService(slog.Default(), nil) + mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil) mockService.SetSessionTTL(2) // 设置 2 秒过期,便于测试 defer mockService.Stop() @@ -305,7 +305,7 @@ func TestSessionCleanup(t *testing.T) { func TestConcurrentMessageAppend(t *testing.T) { gin.SetMode(gin.TestMode) - mockService := service.NewLanguageServerService(slog.Default(), nil) + mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil) defer mockService.Stop() r := gin.New() diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 118dbb0a..1a4f3160 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1467,7 +1467,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if mappedModel == "" { return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) } - // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) billingModel := mappedModel @@ -1494,9 +1493,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } // 获取转换选项 - // Antigravity 上游要求必须包含身份提示词,否则会返回 429 transformOpts := s.getClaudeTransformOptions(ctx) - transformOpts.EnableIdentityPatch = true // 强制启用,Antigravity 上游必需 + transformOpts.EnableIdentityPatch = true transformOpts.PreferredSessionID = sessionID // 转换 Claude 请求为 Gemini 格式 @@ -1505,11 +1503,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request") } - // Antigravity 上游只支持流式请求,统一使用 streamGenerateContent - // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 action := "streamGenerateContent" - // 执行带重试的请求 result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{ ctx: ctx, prefix: prefix, @@ -1524,19 +1519,17 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, accountRepo: s.accountRepo, handleError: s.handleUpstreamError, requestedModel: originalModel, - isStickySession: isStickySession, // Forward 由上层判断粘性会话 - groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除 - sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除 + isStickySession: isStickySession, + groupID: 0, + sessionHash: "", }) if err != nil { - // 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号 if switchErr, ok := IsAntigravityAccountSwitchError(err); ok { return nil, &UpstreamFailoverError{ StatusCode: http.StatusServiceUnavailable, ForceCacheBilling: switchErr.IsStickySession, } } - // 区分客户端取消和真正的上游失败,返回更准确的错误消息 if c.Request.Context().Err() != nil { return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response") } @@ -1548,9 +1541,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - // 优先检测 thinking block 的 signature 相关错误(400)并重试一次: - // Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验, - // 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。 if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) @@ -1567,10 +1557,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Detail: upstreamDetail, }) - // Conservative two-stage fallback: - // 1) Disable top-level thinking + thinking->text - // 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text. - retryStages := []struct { name string strip func(*antigravity.ClaudeRequest) (bool, error) @@ -1609,8 +1595,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, handleError: s.handleUpstreamError, requestedModel: originalModel, isStickySession: isStickySession, - groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除 - sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除 + groupID: 0, + sessionHash: "", }) if retryErr != nil { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ @@ -1663,7 +1649,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Detail: retryUpstreamDetail, }) - // If this stage fixed the signature issue, we stop; otherwise we may try the next stage. if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) { respBody = retryBody resp = &http.Response{ @@ -1674,7 +1659,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, break } - // Still signature-related; capture context and allow next stage. respBody = retryBody resp = &http.Response{ StatusCode: retryResp.StatusCode, @@ -1684,7 +1668,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } } - // Budget 整流:检测 budget_tokens 约束错误并自动修正重试 + // Budget 整流 if resp.StatusCode == http.StatusBadRequest && respBody != nil && !isSignatureRelatedError(respBody) { errMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) { @@ -1699,11 +1683,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Detail: s.getUpstreamErrorDetail(respBody), }) - // 修正 claudeReq 的 thinking 参数(adaptive 模式不修正) if claudeReq.Thinking == nil || claudeReq.Thinking.Type != "adaptive" { retryClaudeReq := claudeReq retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...) - // 创建新的 ThinkingConfig 避免修改原始 claudeReq.Thinking 指针 retryClaudeReq.Thinking = &antigravity.ThinkingConfig{ Type: "enabled", BudgetTokens: BudgetRectifyBudgetTokens, @@ -1758,9 +1740,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } } - // 处理错误响应(重试后仍失败或不触发重试) if resp.StatusCode >= 400 { - // 检测 prompt too long 错误,返回特殊错误类型供上层 fallback if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) @@ -1788,7 +1768,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession) - // 精确匹配服务端配置类 400 错误,触发同账号重试 + failover if resp.StatusCode == http.StatusBadRequest { msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) if isGoogleProjectConfigError(msg) { @@ -1839,7 +1818,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, var firstTokenMs *int var clientDisconnect bool if claudeReq.Stream { - // 客户端要求流式,直接透传转换 streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) if err != nil { logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err) @@ -1849,7 +1827,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, firstTokenMs = streamRes.firstTokenMs clientDisconnect = streamRes.clientDisconnect } else { - // 客户端要求非流式,收集流式响应后转换返回 streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) if err != nil { logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err) @@ -1871,6 +1848,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, }, nil } + func isSignatureRelatedError(respBody []byte) bool { msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) if msg == "" { @@ -4674,3 +4652,61 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage } return usage } + +// ForwardRaw 转发 Claude 格式请求并返回原始上游响应体(调用者负责关闭)。 +// 不依赖 gin.Context,供内部服务(如 LanguageServerService)调用。 +// 复用完整的 token 刷新、模型映射、TLS 指纹和重试逻辑。 +func (s *AntigravityGatewayService) ForwardRaw(ctx context.Context, account *Account, body []byte) (io.ReadCloser, int, error) { + var claudeReq antigravity.ClaudeRequest + if err := json.Unmarshal(body, &claudeReq); err != nil { + return nil, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err) + } + if strings.TrimSpace(claudeReq.Model) == "" { + return nil, http.StatusBadRequest, fmt.Errorf("missing model") + } + + mappedModel := s.getMappedModel(account, claudeReq.Model) + if mappedModel == "" { + return nil, http.StatusForbidden, fmt.Errorf("model %s not in whitelist", claudeReq.Model) + } + thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") + mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) + + if s.tokenProvider == nil { + return nil, http.StatusBadGateway, fmt.Errorf("antigravity token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, http.StatusBadGateway, fmt.Errorf("failed to get access token: %w", err) + } + + projectID := strings.TrimSpace(account.GetCredential("project_id")) + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + transformOpts := s.getClaudeTransformOptions(ctx) + transformOpts.EnableIdentityPatch = true + geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts) + if err != nil { + return nil, http.StatusBadRequest, fmt.Errorf("failed to transform request: %w", err) + } + + wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, geminiBody) + if err != nil { + return nil, http.StatusInternalServerError, fmt.Errorf("failed to wrap request: %w", err) + } + + upstreamReq, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, wrappedBody) + if err != nil { + return nil, http.StatusInternalServerError, fmt.Errorf("failed to build upstream request: %w", err) + } + + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, http.StatusBadGateway, fmt.Errorf("upstream request failed: %w", err) + } + + return resp.Body, resp.StatusCode, nil +} diff --git a/backend/internal/service/language_server_service.go b/backend/internal/service/language_server_service.go index dec67519..986c2574 100644 --- a/backend/internal/service/language_server_service.go +++ b/backend/internal/service/language_server_service.go @@ -2,14 +2,11 @@ package service import ( "bufio" - "bytes" "context" "encoding/json" "fmt" "io" "log/slog" - "net/http" - "os" "strings" "sync" "time" @@ -28,7 +25,7 @@ type CascadeSession struct { } // LanguageServerService 业务逻辑层 -// 处理 Cascade Agent 流程,转发到上游 API +// 处理 Cascade Agent 流程,通过 AntigravityGatewayService 转发到上游 API type LanguageServerService struct { // 会话管理 cascadeSessions map[string]*CascadeSession @@ -37,9 +34,9 @@ type LanguageServerService struct { // 上游 HTTP 服务(用于发送请求) httpUpstream HTTPUpstream - // 上游配置 - upstreamBaseURL string - upstreamAPIKey string + // Antigravity 网关(账号池调度 + TLS 指纹 + token 刷新) + antigravitySvc *AntigravityGatewayService + accountRepo AccountRepository // 日志 logger *slog.Logger @@ -59,15 +56,17 @@ type LanguageServerService struct { func NewLanguageServerService( logger *slog.Logger, httpUpstream HTTPUpstream, + antigravitySvc *AntigravityGatewayService, + accountRepo AccountRepository, ) *LanguageServerService { svc := &LanguageServerService{ cascadeSessions: make(map[string]*CascadeSession), logger: logger, httpUpstream: httpUpstream, - upstreamBaseURL: strings.TrimSuffix(os.Getenv("ANTHROPIC_BASE_URL"), "/"), - upstreamAPIKey: os.Getenv("ANTHROPIC_API_KEY"), + antigravitySvc: antigravitySvc, + accountRepo: accountRepo, rateLimiter: make(chan struct{}, 100), // 改进 1: 限制 100 个并发消息 - sessionTTLSeconds: 3600, // 改进 3: 会话默认 1 小时过期 + sessionTTLSeconds: 3600, // 改进 3: 会话默认 1 小时过期 stopCleanup: make(chan struct{}), } @@ -380,46 +379,43 @@ func (svc *LanguageServerService) GetStatus(ctx context.Context) (string, error) // 内部方法 // ============================================================================ -// callUpstreamAPI 调用上游 Anthropic API -// 这是关键方法:需要注入所有伪装信息 -// -// 伪装层包括: -// 1. User-Agent(来自 metadata 或动态生成) -// 2. 设备指纹(machine_id, mac_machine_id, dev_device_id, sqm_id) -// 3. TLS 指纹(通过 http.Transport 处理) -// 4. OAuth token 自动刷新 -// 5. 请求头完整性 +// callUpstreamAPI 通过 AntigravityGatewayService 调用上游 API。 +// 复用账号池调度、模型映射、TLS 指纹伪装、token 刷新和重试逻辑。 func (svc *LanguageServerService) callUpstreamAPI( ctx context.Context, session *CascadeSession, updateChan chan<- interface{}, ) { - // 检查上游配置 - if svc.upstreamBaseURL == "" || svc.upstreamAPIKey == "" { - svc.logger.Error("upstream api configuration missing", - "has_base_url", svc.upstreamBaseURL != "", - "has_api_key", svc.upstreamAPIKey != "", - ) + if svc.antigravitySvc == nil || svc.accountRepo == nil { updateChan <- map[string]interface{}{ "type": "error", - "error": "upstream api not configured", + "error": "antigravity gateway not configured", } return } - // 1. 准备请求体 - requestBody := map[string]interface{}{ - "model": session.ModelName, - "messages": session.Messages, - "stream": true, + // 1. 选取第一个可用的 Antigravity 账号 + accounts, err := svc.accountRepo.ListByPlatform(ctx, PlatformAntigravity) + if err != nil || len(accounts) == 0 { + svc.logger.Error("no antigravity accounts available", "session_id", session.ID, "error", err) + updateChan <- map[string]interface{}{ + "type": "error", + "error": "no antigravity accounts available", + } + return } + account := &accounts[0] + // 2. 准备 Claude 格式请求体 + requestBody := map[string]interface{}{ + "model": session.ModelName, + "messages": session.Messages, + "stream": true, + "max_tokens": 8192, + } bodyJSON, err := json.Marshal(requestBody) if err != nil { - svc.logger.Error("failed to marshal request", - "session_id", session.ID, - "error", err, - ) + svc.logger.Error("failed to marshal request", "session_id", session.ID, "error", err) updateChan <- map[string]interface{}{ "type": "error", "error": "failed to prepare request", @@ -427,87 +423,44 @@ func (svc *LanguageServerService) callUpstreamAPI( return } - // 2. 构建上游请求 URL - upstreamURL := svc.upstreamBaseURL + "/v1/messages" + svc.logger.Debug("forwarding via antigravity", "session_id", session.ID, "model", session.ModelName, "account_id", account.ID) - // 3. 创建 HTTP 请求 - req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(bodyJSON)) + // 3. 通过 AntigravityGatewayService 转发(完整 TLS 指纹 + token 刷新 + 重试) + respBody, statusCode, err := svc.antigravitySvc.ForwardRaw(ctx, account, bodyJSON) if err != nil { - svc.logger.Error("failed to create request", - "session_id", session.ID, - "error", err, - ) - updateChan <- map[string]interface{}{ - "type": "error", - "error": "failed to create request", - } - return - } - - // 4. 设置基础请求头 - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+session.Token) - req.Header.Set("x-api-key", session.Token) // Claude API 兼容 - - // 5. 应用伪装信息 - if userAgent := session.Metadata["user-agent"]; userAgent != "" { - req.Header.Set("User-Agent", userAgent) - } - - // 提取其他伪装 headers(如果在 metadata 中) - if customHeaders := session.Metadata["custom-headers"]; customHeaders != "" { - // 可以在这里解析并应用自定义 headers - } - - svc.logger.Debug("sending upstream request", - "session_id", session.ID, - "url", upstreamURL, - "model", session.ModelName, - ) - - // 6. 发送请求 - resp, err := svc.httpUpstream.Do(req, "", 0, 10) - if err != nil { - svc.logger.Error("upstream request failed", - "session_id", session.ID, - "error", err, - ) + svc.logger.Error("upstream request failed", "session_id", session.ID, "error", err) updateChan <- map[string]interface{}{ "type": "error", "error": fmt.Sprintf("upstream request failed: %v", err), } return } - defer func() { _ = resp.Body.Close() }() + defer func() { _ = respBody.Close() }() - // 7. 处理错误响应 - if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - svc.logger.Error("upstream error response", - "session_id", session.ID, - "status_code", resp.StatusCode, - "body", string(respBody), - ) + // 4. 处理错误响应 + if statusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(respBody, 2<<20)) + svc.logger.Error("upstream error response", "session_id", session.ID, "status_code", statusCode, "body", string(body)) updateChan <- map[string]interface{}{ "type": "error", - "status_code": resp.StatusCode, - "error": string(respBody), + "status_code": statusCode, + "error": string(body), } return } - // 8. 处理流式响应 - svc.streamUpstreamResponse(ctx, session.ID, resp, updateChan) + // 5. 流式转发响应 + svc.streamUpstreamResponse(ctx, session.ID, respBody, updateChan) } // streamUpstreamResponse 处理上游 SSE 流式响应 func (svc *LanguageServerService) streamUpstreamResponse( ctx context.Context, sessionID string, - resp *http.Response, + body io.ReadCloser, updateChan chan<- interface{}, ) { - scanner := bufio.NewScanner(resp.Body) + scanner := bufio.NewScanner(body) // 设置合理的缓冲区大小 scanner.Buffer(make([]byte, 64*1024), 512*1024) diff --git a/backend/internal/service/lsrpc_handler.go b/backend/internal/service/lsrpc_handler.go new file mode 100644 index 00000000..29cfd92d --- /dev/null +++ b/backend/internal/service/lsrpc_handler.go @@ -0,0 +1,353 @@ +package service + +import ( + "context" + "fmt" + "io/fs" + "log/slog" + "net/http" + "os" + "path/filepath" + "time" + + connect "connectrpc.com/connect" + "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb" + "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pbconnect" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const upstreamLSRPCBaseURL = "https://cloudcode-pa.googleapis.com" + +// LSRPCHandler implements LanguageServerServiceHandler by proxying to the real upstream +// lsrpc service using OAuth tokens obtained from AntigravityGatewayService. +// File RPCs (ReadFile/WriteFile/ReadDir/etc.) operate on the local filesystem. +type LSRPCHandler struct { + language_server_pbconnect.UnimplementedLanguageServerServiceHandler + + antigravitySvc *AntigravityGatewayService + accountRepo AccountRepository + logger *slog.Logger +} + +// NewLSRPCHandler creates a new LSRPCHandler. +func NewLSRPCHandler( + antigravitySvc *AntigravityGatewayService, + accountRepo AccountRepository, + logger *slog.Logger, +) *LSRPCHandler { + if logger == nil { + logger = slog.Default() + } + return &LSRPCHandler{ + antigravitySvc: antigravitySvc, + accountRepo: accountRepo, + logger: logger, + } +} + +// upstreamClient creates a connectrpc client to the real lsrpc upstream, +// authenticated with the OAuth token from the given account. +func (h *LSRPCHandler) upstreamClient(ctx context.Context) (language_server_pbconnect.LanguageServerServiceClient, error) { + accounts, err := h.accountRepo.ListByPlatform(ctx, PlatformAntigravity) + if err != nil || len(accounts) == 0 { + return nil, fmt.Errorf("no antigravity accounts available: %w", err) + } + account := &accounts[0] + + tokenProvider := h.antigravitySvc.GetTokenProvider() + if tokenProvider == nil { + return nil, fmt.Errorf("antigravity token provider not configured") + } + accessToken, err := tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + httpClient := &http.Client{ + Timeout: 5 * time.Minute, + Transport: &bearerTransport{ + base: http.DefaultTransport, + token: accessToken, + }, + } + + client := language_server_pbconnect.NewLanguageServerServiceClient( + httpClient, + upstreamLSRPCBaseURL, + connect.WithGRPC(), + ) + return client, nil +} + +// bearerTransport injects Authorization: Bearer into every request. +type bearerTransport struct { + base http.RoundTripper + token string +} + +func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + clone := req.Clone(req.Context()) + clone.Header.Set("Authorization", "Bearer "+t.token) + return t.base.RoundTrip(clone) +} + +// ============================================================================ +// Cascade RPCs — proxied to real upstream +// ============================================================================ + +func (h *LSRPCHandler) StartCascade( + ctx context.Context, + req *connect.Request[language_server_pb.StartCascadeRequest], +) (*connect.Response[language_server_pb.StartCascadeResponse], error) { + client, err := h.upstreamClient(ctx) + if err != nil { + return nil, connect.NewError(connect.CodeUnavailable, err) + } + return client.StartCascade(ctx, req) +} + +func (h *LSRPCHandler) SendUserCascadeMessage( + ctx context.Context, + req *connect.Request[language_server_pb.SendUserCascadeMessageRequest], + stream *connect.ServerStream[language_server_pb.CascadeReactiveUpdate], +) error { + client, err := h.upstreamClient(ctx) + if err != nil { + return connect.NewError(connect.CodeUnavailable, err) + } + + upstreamStream, err := client.SendUserCascadeMessage(ctx, req) + if err != nil { + return err + } + defer upstreamStream.Close() + + for upstreamStream.Receive() { + if err := stream.Send(upstreamStream.Msg()); err != nil { + return err + } + } + return upstreamStream.Err() +} + +func (h *LSRPCHandler) CancelCascadeInvocation( + ctx context.Context, + req *connect.Request[language_server_pb.CancelCascadeInvocationRequest], +) (*connect.Response[language_server_pb.CancelCascadeInvocationResponse], error) { + client, err := h.upstreamClient(ctx) + if err != nil { + return nil, connect.NewError(connect.CodeUnavailable, err) + } + return client.CancelCascadeInvocation(ctx, req) +} + +func (h *LSRPCHandler) AcknowledgeCascadeCodeEdit( + ctx context.Context, + req *connect.Request[language_server_pb.AcknowledgeCascadeCodeEditRequest], +) (*connect.Response[language_server_pb.AcknowledgeCascadeCodeEditResponse], error) { + client, err := h.upstreamClient(ctx) + if err != nil { + return nil, connect.NewError(connect.CodeUnavailable, err) + } + return client.AcknowledgeCascadeCodeEdit(ctx, req) +} + +// ============================================================================ +// Model config RPCs — proxied to real upstream +// ============================================================================ + +func (h *LSRPCHandler) GetCascadeModelConfigs( + ctx context.Context, + req *connect.Request[language_server_pb.GetCascadeModelConfigsRequest], +) (*connect.Response[language_server_pb.GetCascadeModelConfigsResponse], error) { + client, err := h.upstreamClient(ctx) + if err != nil { + // Fall back to static list when upstream unavailable. + return connect.NewResponse(&language_server_pb.GetCascadeModelConfigsResponse{ + Models: staticCascadeModels(), + }), nil + } + resp, err := client.GetCascadeModelConfigs(ctx, req) + if err != nil { + return connect.NewResponse(&language_server_pb.GetCascadeModelConfigsResponse{ + Models: staticCascadeModels(), + }), nil + } + return resp, nil +} + +func (h *LSRPCHandler) GetCommandModelConfigs( + ctx context.Context, + req *connect.Request[language_server_pb.GetCommandModelConfigsRequest], +) (*connect.Response[language_server_pb.GetCommandModelConfigsResponse], error) { + client, err := h.upstreamClient(ctx) + if err != nil { + return connect.NewResponse(&language_server_pb.GetCommandModelConfigsResponse{ + Models: staticCascadeModels(), + }), nil + } + resp, err := client.GetCommandModelConfigs(ctx, req) + if err != nil { + return connect.NewResponse(&language_server_pb.GetCommandModelConfigsResponse{ + Models: staticCascadeModels(), + }), nil + } + return resp, nil +} + +// staticCascadeModels returns a hard-coded model list as fallback. +func staticCascadeModels() []*language_server_pb.ModelConfig { + return []*language_server_pb.ModelConfig{ + {Name: "claude-opus-4-7", DisplayName: "Claude Opus 4.7", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 32000, SupportsImages: true, Provider: "anthropic"}, + {Name: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 32000, SupportsImages: true, Provider: "anthropic"}, + {Name: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", MaxTokens: 200000, SupportsImages: true, Provider: "anthropic"}, + {Name: "claude-haiku-4-5", DisplayName: "Claude Haiku 4.5", MaxTokens: 200000, SupportsImages: true, Provider: "anthropic"}, + } +} + +// ============================================================================ +// File RPCs — local filesystem implementation +// ============================================================================ + +func (h *LSRPCHandler) ReadFile( + ctx context.Context, + req *connect.Request[language_server_pb.ReadFileRequest], +) (*connect.Response[language_server_pb.ReadFileResponse], error) { + path := req.Msg.GetPath() + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("file not found: %s", path)) + } + return nil, connect.NewError(connect.CodeInternal, err) + } + return connect.NewResponse(&language_server_pb.ReadFileResponse{ + Content: string(data), + }), nil +} + +func (h *LSRPCHandler) WriteFile( + ctx context.Context, + req *connect.Request[language_server_pb.WriteFileRequest], +) (*connect.Response[language_server_pb.WriteFileResponse], error) { + path := req.Msg.GetPath() + if req.Msg.GetCreateParent() { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return nil, connect.NewError(connect.CodeInternal, err) + } + } + if err := os.WriteFile(path, []byte(req.Msg.GetContent()), 0o644); err != nil { + return nil, connect.NewError(connect.CodeInternal, err) + } + return connect.NewResponse(&language_server_pb.WriteFileResponse{}), nil +} + +func (h *LSRPCHandler) ReadDir( + ctx context.Context, + req *connect.Request[language_server_pb.ReadDirRequest], +) (*connect.Response[language_server_pb.ReadDirResponse], error) { + path := req.Msg.GetPath() + entries, err := os.ReadDir(path) + if err != nil { + if os.IsNotExist(err) { + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("directory not found: %s", path)) + } + return nil, connect.NewError(connect.CodeInternal, err) + } + + files := make([]*language_server_pb.FileInfo, 0, len(entries)) + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + continue + } + files = append(files, fileInfoFromOS(entry.Name(), info)) + } + return connect.NewResponse(&language_server_pb.ReadDirResponse{ + Files: files, + }), nil +} + +func (h *LSRPCHandler) DeleteFileOrDirectory( + ctx context.Context, + req *connect.Request[language_server_pb.DeleteFileOrDirectoryRequest], +) (*connect.Response[language_server_pb.DeleteFileOrDirectoryResponse], error) { + path := req.Msg.GetPath() + if err := os.RemoveAll(path); err != nil { + return nil, connect.NewError(connect.CodeInternal, err) + } + return connect.NewResponse(&language_server_pb.DeleteFileOrDirectoryResponse{}), nil +} + +func (h *LSRPCHandler) StatUri( + ctx context.Context, + req *connect.Request[language_server_pb.StatUriRequest], +) (*connect.Response[language_server_pb.StatUriResponse], error) { + path := req.Msg.GetPath() + info, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("path not found: %s", path)) + } + return nil, connect.NewError(connect.CodeInternal, err) + } + return connect.NewResponse(&language_server_pb.StatUriResponse{ + FileInfo: fileInfoFromOS(info.Name(), info), + }), nil +} + +func (h *LSRPCHandler) WatchDirectory( + ctx context.Context, + req *connect.Request[language_server_pb.WatchDirectoryRequest], + stream *connect.ServerStream[language_server_pb.WatchDirectoryResponse], +) error { + // Block until context is cancelled — real FS watching requires fsnotify which + // is not in the dependency graph yet. This satisfies the interface contract + // without crashing; the client will get an EOF when the connection closes. + <-ctx.Done() + return nil +} + +// ============================================================================ +// Health RPCs +// ============================================================================ + +func (h *LSRPCHandler) Heartbeat( + ctx context.Context, + req *connect.Request[language_server_pb.HeartbeatRequest], +) (*connect.Response[language_server_pb.HeartbeatResponse], error) { + return connect.NewResponse(&language_server_pb.HeartbeatResponse{ + Healthy: true, + Version: "sub2api", + }), nil +} + +func (h *LSRPCHandler) GetStatus( + ctx context.Context, + req *connect.Request[language_server_pb.GetStatusRequest], +) (*connect.Response[language_server_pb.GetStatusResponse], error) { + return connect.NewResponse(&language_server_pb.GetStatusResponse{ + Status: "running", + Version: antigravity.BaseURL, + }), nil +} + +// ============================================================================ +// Helpers +// ============================================================================ + +func fileInfoFromOS(name string, info fs.FileInfo) *language_server_pb.FileInfo { + t := language_server_pb.FileInfo_FILE + if info.IsDir() { + t = language_server_pb.FileInfo_DIRECTORY + } else if info.Mode()&os.ModeSymlink != 0 { + t = language_server_pb.FileInfo_SYMLINK + } + return &language_server_pb.FileInfo{ + Path: name, + Type: t, + Size: info.Size(), + ModifiedTime: timestamppb.New(info.ModTime()), + } +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 88426563..30789816 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -472,8 +472,8 @@ var ProviderSet = wire.NewSet( ) // ProvideLanguageServerService creates LanguageServerService with injected dependencies -func ProvideLanguageServerService(httpUpstream HTTPUpstream) *LanguageServerService { - return NewLanguageServerService(slog.Default(), httpUpstream) +func ProvideLanguageServerService(httpUpstream HTTPUpstream, antigravitySvc *AntigravityGatewayService, accountRepo AccountRepository) *LanguageServerService { + return NewLanguageServerService(slog.Default(), httpUpstream, antigravitySvc, accountRepo) } // ProvidePaymentConfigService wraps NewPaymentConfigService to accept the named diff --git a/backend/lsrpc_test b/backend/lsrpc_test new file mode 100755 index 00000000..b1a925a6 Binary files /dev/null and b/backend/lsrpc_test differ diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 5d532ce1..55305ff2 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -1,15 +1,10 @@ # ============================================================================= -# Sub2API Docker Compose Configuration (负载均衡版) +# Sub2API Docker Compose Configuration # ============================================================================= # Quick Start: # 1. Copy .env.example to .env and configure # 2. docker compose up -d # 3. Check logs: docker compose logs -f -# 4. Access: http://localhost (via nginx) -# -# 扩缩容: -# docker compose up -d --scale sub2api=5 # 扩到 5 个实例 -# docker compose up -d --scale sub2api=2 # 缩回 2 个实例 # # 注意事项: # - JWT_SECRET / TOTP_ENCRYPTION_KEY 必须固定,多实例共享同一个值 @@ -20,36 +15,7 @@ services: # =========================================================================== - # Nginx 负载均衡(入口) - # =========================================================================== - nginx: - image: nginx:alpine - container_name: sub2api-nginx - restart: unless-stopped - ulimits: - nofile: - soft: 65535 - hard: 65535 - ports: - - "0.0.0.0:80:80" - - "0.0.0.0:443:443" - volumes: - - ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro - - ./nginx/certs:/etc/nginx/certs:ro - depends_on: - sub2api: - condition: service_healthy - networks: - - sub2api-network - healthcheck: - test: [ "CMD", "wget", "-q", "-T", "3", "-O", "/dev/null", "http://localhost/health" ] - interval: 30s - timeout: 10s - retries: 3 - start_period: 10s - - # =========================================================================== - # Sub2API Application(多实例,通过 --scale 控制数量) + # Sub2API Application # =========================================================================== sub2api: image: docker.io/zfc931912343/sub2api:latest @@ -58,9 +24,8 @@ services: nofile: soft: 100000 hard: 100000 - # 不直接暴露端口,由 nginx 代理 - expose: - - "8080" + ports: + - "0.0.0.0:80:8080" volumes: - sub2api_data:/app/data # Optional: 挂载自定义 config.yaml(先从 config.example.yaml 复制并修改)