sub2api/backend/internal/service/language_server_service.go
win 3403e8401c fix: revert antigravity Forward to v1internal REST path, remove broken lsrpc upstream call
lsrpc is local IPC (IDE ↔ language_server binary), not an upstream protocol.
cloudcode-pa.googleapis.com does not serve gRPC/lsrpc endpoints.
Restores antigravityRetryLoop + streamGenerateContent path which was working.
Removes antigravity_lsrpc.go (upstream caller) and lsrpc_test cmd.
Keeps lsrpc_handler.go (server side, receives IDE connections).
2026-04-19 20:03:34 +08:00

531 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// CascadeSession holds the state of one Cascade Agent conversation.
type CascadeSession struct {
	// ID identifies the session; StartCascade assigns a random UUID.
	ID string
	// ModelName is the model requested when the session was started.
	ModelName string
	// Messages is the conversation history. Entries appended in this file
	// carry "role" and "content" keys.
	Messages []map[string]interface{}
	// Metadata holds client attributes such as device fingerprint and User-Agent.
	Metadata map[string]string
	// Token is the OAuth token a caller must present to use this session.
	Token string
	// CreatedAt is the creation time in Unix milliseconds.
	CreatedAt int64
}
// LanguageServerService implements the Cascade Agent business logic: it keeps
// the in-memory session table and forwards conversations upstream through
// AntigravityGatewayService.
type LanguageServerService struct {
	// cascadeSessions maps session ID to session; accessed under sessionMutex.
	cascadeSessions map[string]*CascadeSession
	sessionMutex    sync.RWMutex

	// httpUpstream is the upstream HTTP client.
	// NOTE(review): not referenced by any method in this file.
	httpUpstream HTTPUpstream

	// antigravitySvc forwards requests upstream (described by the original
	// comments as handling account-pool scheduling, TLS fingerprinting and
	// token refresh).
	antigravitySvc *AntigravityGatewayService
	// accountRepo lists the accounts that can be used for forwarding.
	accountRepo AccountRepository

	logger *slog.Logger

	// rateLimiter is a counting semaphore bounding concurrently processed
	// messages: sending into it acquires a slot, receiving releases one.
	rateLimiter chan struct{}

	// sessionTTLSeconds is how long after creation a session is expired.
	sessionTTLSeconds int64

	// cleanupTicker drives the periodic expiry sweep; stopCleanup tells the
	// sweep goroutine to exit.
	cleanupTicker *time.Ticker
	stopCleanup   chan struct{}
}
// NewLanguageServerService wires a LanguageServerService to its dependencies
// and starts the background goroutine that expires old sessions.
//
// The service admits at most 100 concurrently processed messages and expires
// sessions one hour after creation. Call Stop to end the cleanup goroutine.
func NewLanguageServerService(
	logger *slog.Logger,
	httpUpstream HTTPUpstream,
	antigravitySvc *AntigravityGatewayService,
	accountRepo AccountRepository,
) *LanguageServerService {
	const (
		maxConcurrentMessages    = 100
		defaultSessionTTLSeconds = 3600
	)

	svc := &LanguageServerService{
		logger:         logger,
		httpUpstream:   httpUpstream,
		antigravitySvc: antigravitySvc,
		accountRepo:    accountRepo,

		cascadeSessions:   make(map[string]*CascadeSession),
		rateLimiter:       make(chan struct{}, maxConcurrentMessages),
		sessionTTLSeconds: defaultSessionTTLSeconds,
		stopCleanup:       make(chan struct{}),
	}
	svc.startSessionCleanup()
	return svc
}
// startSessionCleanup creates a one-minute ticker and spawns a goroutine that
// runs cleanupExpiredSessions on every tick. The goroutine stops the ticker
// and exits once it receives from svc.stopCleanup.
func (svc *LanguageServerService) startSessionCleanup() {
	svc.cleanupTicker = time.NewTicker(time.Minute)

	go func() {
		for {
			select {
			case <-svc.stopCleanup:
				svc.cleanupTicker.Stop()
				return
			case <-svc.cleanupTicker.C:
				svc.cleanupExpiredSessions()
			}
		}
	}()
}
// cleanupExpiredSessions deletes every session created more than
// sessionTTLSeconds ago. It logs the number removed when any were.
func (svc *LanguageServerService) cleanupExpiredSessions() {
	// A session is expired when now-CreatedAt > TTL, i.e. CreatedAt < cutoff.
	cutoff := getCurrentTimeMS() - svc.sessionTTLSeconds*1000

	svc.sessionMutex.Lock()
	defer svc.sessionMutex.Unlock()

	removed := 0
	for id, s := range svc.cascadeSessions {
		if s.CreatedAt >= cutoff {
			continue
		}
		delete(svc.cascadeSessions, id)
		removed++
	}

	if removed == 0 {
		return
	}
	svc.logger.Info("expired sessions cleaned up",
		"deleted_count", removed,
		"remaining_sessions", len(svc.cascadeSessions),
	)
}
// Stop shuts down the background session-cleanup goroutine started by
// NewLanguageServerService. It is safe to call more than once.
func (svc *LanguageServerService) Stop() {
	if svc.stopCleanup == nil {
		// Service was not built by NewLanguageServerService; nothing to stop.
		return
	}

	// A non-blocking send on the unbuffered stopCleanup channel is lost
	// whenever the cleanup goroutine is busy in cleanupExpiredSessions rather
	// than parked in its select, which leaks the goroutine and its ticker.
	// Closing the channel is observed whenever the goroutine next selects.
	//
	// sessionMutex serializes concurrent Stop calls so the channel is closed
	// exactly once (closing an already-closed channel panics).
	svc.sessionMutex.Lock()
	defer svc.sessionMutex.Unlock()

	select {
	case <-svc.stopCleanup:
		// Already closed by an earlier Stop call.
	default:
		close(svc.stopCleanup)
	}
}
// SetSessionTTL overrides how long, in seconds, a session lives after
// creation. It is intended for tests.
//
// NOTE(review): the field is written without synchronization and read by the
// cleanup goroutine, so set it before a sweep could run concurrently.
func (svc *LanguageServerService) SetSessionTTL(ttlSeconds int64) {
	svc.sessionTTLSeconds = ttlSeconds
}
// GetCascadeSessions exposes the session table for tests.
//
// The returned value is the service's live internal map, not a copy. The read
// lock covers only fetching it, so callers must not read or modify the map
// while other goroutines may be creating, updating or expiring sessions.
func (svc *LanguageServerService) GetCascadeSessions() map[string]*CascadeSession {
	svc.sessionMutex.RLock()
	sessions := svc.cascadeSessions
	svc.sessionMutex.RUnlock()
	return sessions
}
// ============================================================================
// Cascade 业务逻辑
// ============================================================================
// StartCascade creates a new Cascade Agent session and returns its ID.
//
// model and token are required. A non-empty systemPrompt is stored as the
// first message, with role "user". metadata is stored as given, not copied.
// ctx is currently unused.
func (svc *LanguageServerService) StartCascade(
	ctx context.Context,
	model string,
	systemPrompt string,
	metadata map[string]string,
	token string,
) (string, error) {
	switch {
	case model == "":
		return "", fmt.Errorf("model is required")
	case token == "":
		return "", fmt.Errorf("oauth token is required")
	}

	id := uuid.New().String()

	history := make([]map[string]interface{}, 0)
	if systemPrompt != "" {
		history = append(history, map[string]interface{}{
			"role":    "user",
			"content": systemPrompt,
		})
	}

	session := &CascadeSession{
		ID:        id,
		ModelName: model,
		Messages:  history,
		Metadata:  metadata,
		Token:     token,
		CreatedAt: getCurrentTimeMS(),
	}

	svc.sessionMutex.Lock()
	svc.cascadeSessions[id] = session
	svc.sessionMutex.Unlock()

	svc.logger.Info("cascade session started",
		"session_id", id,
		"model", model,
		"has_system_prompt", systemPrompt != "")
	return id, nil
}
// SendUserMessage appends userMessage to the Cascade session cascadeID and
// forwards the conversation upstream in a background goroutine.
//
// A slot from svc.rateLimiter is acquired first. When none is free, it waits
// up to 30 seconds or until ctx is done. The slot is released when the
// background goroutine finishes, or immediately if validation fails.
//
// The returned channel receives upstream events and error events, and is
// closed when the upstream call completes. An error is returned when:
//   - ctx is done before a slot is acquired,
//   - the rate-limit wait times out,
//   - the session does not exist, or
//   - token does not match the session's token.
func (svc *LanguageServerService) SendUserMessage(
	ctx context.Context,
	cascadeID string,
	userMessage string,
	token string,
) (<-chan interface{}, error) {
	if err := svc.acquireMessageSlot(ctx); err != nil {
		return nil, err
	}

	svc.sessionMutex.RLock()
	session, exists := svc.cascadeSessions[cascadeID]
	svc.sessionMutex.RUnlock()
	if !exists {
		svc.releaseMessageSlot()
		return nil, fmt.Errorf("cascade session not found: %s", cascadeID)
	}

	if token != session.Token {
		svc.releaseMessageSlot()
		return nil, fmt.Errorf("invalid token for session")
	}

	// Append copy-on-write under the lock, then take a private snapshot for
	// the upstream call. The background goroutine must not read
	// session.Messages directly: a concurrent SendUserMessage on the same
	// session replaces that slice under sessionMutex, and an unlocked read
	// from another goroutine would be a data race. Published slices are never
	// mutated (every append allocates a new one), so the snapshot stays valid.
	svc.sessionMutex.Lock()
	newMessages := make([]map[string]interface{}, len(session.Messages)+1)
	copy(newMessages, session.Messages)
	newMessages[len(newMessages)-1] = map[string]interface{}{
		"role":    "user",
		"content": userMessage,
	}
	session.Messages = newMessages
	snapshot := *session
	svc.sessionMutex.Unlock()

	updateChan := make(chan interface{}, 100)

	go func() {
		defer func() {
			close(updateChan)
			svc.releaseMessageSlot()
		}()
		svc.callUpstreamAPI(ctx, &snapshot, updateChan)
	}()

	svc.logger.Info("user message sent to cascade",
		"session_id", cascadeID,
		"message_length", len(userMessage),
		// The semaphore's length is the number of slots currently held, i.e.
		// in-flight messages. The old 100-len(...) reported free slots instead.
		"concurrent_requests", len(svc.rateLimiter),
	)
	return updateChan, nil
}

// acquireMessageSlot reserves one slot in svc.rateLimiter. When no slot is
// free it waits up to 30 seconds, failing early if ctx is done.
func (svc *LanguageServerService) acquireMessageSlot(ctx context.Context) error {
	// Fast path: a slot is free, or ctx is already done.
	select {
	case svc.rateLimiter <- struct{}{}:
		return nil
	case <-ctx.Done():
		return fmt.Errorf("context cancelled")
	default:
	}

	timer := time.NewTimer(30 * time.Second)
	defer timer.Stop()

	select {
	case svc.rateLimiter <- struct{}{}:
		return nil
	case <-ctx.Done():
		return fmt.Errorf("context cancelled while waiting for rate limit")
	case <-timer.C:
		return fmt.Errorf("rate limit timeout: too many concurrent messages")
	}
}

// releaseMessageSlot frees a slot previously reserved by acquireMessageSlot.
func (svc *LanguageServerService) releaseMessageSlot() {
	<-svc.rateLimiter
}
// CancelCascade checks that the session cascadeID exists and logs its
// cancellation. It returns an error if the session is unknown.
//
// It does not yet interrupt an in-flight upstream call (see TODO), and it
// does not remove the session.
func (svc *LanguageServerService) CancelCascade(
	ctx context.Context,
	cascadeID string,
) error {
	svc.sessionMutex.Lock()
	_, found := svc.cascadeSessions[cascadeID]
	svc.sessionMutex.Unlock()

	if found {
		// TODO: cancel the in-flight upstream API call.
		svc.logger.Info("cascade cancelled", "session_id", cascadeID)
		return nil
	}
	return fmt.Errorf("cascade session not found: %s", cascadeID)
}
// ============================================================================
// 模型配置
// ============================================================================
// ModelConfig describes a model advertised by GetAvailableModels.
type ModelConfig struct {
	// Name is the model identifier used in requests.
	Name string
	// DisplayName is the human-readable model name.
	DisplayName string
	// MaxTokens is the advertised token limit for the model.
	MaxTokens int
	// SupportsThinking reports whether extended thinking is offered.
	SupportsThinking bool
	// ThinkingBudget is the advertised thinking budget; zero when unused.
	ThinkingBudget int
	// SupportsImages reports whether image input is accepted.
	SupportsImages bool
	// Provider names the model vendor, e.g. "anthropic" or "google".
	Provider string
}
// GetAvailableModels returns the static list of models offered by the
// service. A new slice is built on every call. ctx is unused and the error is
// always nil.
func (svc *LanguageServerService) GetAvailableModels(ctx context.Context) ([]ModelConfig, error) {
	// anthropicModel builds an Anthropic entry with a 200k token limit and
	// image support. Thinking is enabled exactly when thinkingBudget > 0.
	anthropicModel := func(name, displayName string, thinkingBudget int) ModelConfig {
		return ModelConfig{
			Name:             name,
			DisplayName:      displayName,
			MaxTokens:        200000,
			SupportsThinking: thinkingBudget > 0,
			ThinkingBudget:   thinkingBudget,
			SupportsImages:   true,
			Provider:         "anthropic",
		}
	}

	catalog := []ModelConfig{
		anthropicModel("claude-opus-4-7", "Claude Opus 4.7", 32000),
		anthropicModel("claude-sonnet-4-7", "Claude Sonnet 4.7", 16000),
		anthropicModel("claude-opus-4-6", "Claude Opus 4.6", 32000),
		anthropicModel("claude-sonnet-4-6", "Claude Sonnet 4.6", 0),
		anthropicModel("claude-haiku-4-5", "Claude Haiku 4.5", 0),
		{
			Name:           "gemini-3-pro",
			DisplayName:    "Gemini 3 Pro",
			MaxTokens:      128000,
			SupportsImages: true,
			Provider:       "google",
		},
	}
	return catalog, nil
}
// ============================================================================
// 状态查询
// ============================================================================
// GetStatus reports the service status. It currently always returns
// "running" with a nil error; upstream connectivity is not checked.
func (svc *LanguageServerService) GetStatus(ctx context.Context) (string, error) {
	// TODO: check upstream API connectivity.
	const status = "running"
	return status, nil
}
// ============================================================================
// 内部方法
// ============================================================================
// callUpstreamAPI sends the session's conversation upstream through
// AntigravityGatewayService.ForwardRaw and relays the result to updateChan.
//
// Each failure is reported as a single {"type": "error", "error": ...} event.
// Responses with status >= 400 also carry "status_code". A successful response
// is streamed by streamUpstreamResponse. updateChan is not closed here.
//
// NOTE(review): the account used is simply the first one returned by
// ListByPlatform; no status or health filtering happens here. The original
// comment says ForwardRaw handles model mapping, TLS fingerprinting, token
// refresh and retries — confirm against AntigravityGatewayService.
func (svc *LanguageServerService) callUpstreamAPI(
	ctx context.Context,
	session *CascadeSession,
	updateChan chan<- interface{},
) {
	// reportError pushes a plain error event to the client.
	reportError := func(message string) {
		updateChan <- map[string]interface{}{
			"type":  "error",
			"error": message,
		}
	}

	if svc.antigravitySvc == nil || svc.accountRepo == nil {
		reportError("antigravity gateway not configured")
		return
	}

	accounts, err := svc.accountRepo.ListByPlatform(ctx, PlatformAntigravity)
	if err != nil || len(accounts) == 0 {
		svc.logger.Error("no antigravity accounts available", "session_id", session.ID, "error", err)
		reportError("no antigravity accounts available")
		return
	}
	account := &accounts[0]

	// Claude-format streaming request body.
	bodyJSON, err := json.Marshal(map[string]interface{}{
		"model":      session.ModelName,
		"messages":   session.Messages,
		"stream":     true,
		"max_tokens": 8192,
	})
	if err != nil {
		svc.logger.Error("failed to marshal request", "session_id", session.ID, "error", err)
		reportError("failed to prepare request")
		return
	}

	svc.logger.Debug("forwarding via antigravity", "session_id", session.ID, "model", session.ModelName, "account_id", account.ID)

	respBody, statusCode, err := svc.antigravitySvc.ForwardRaw(ctx, account, bodyJSON)
	if err != nil {
		svc.logger.Error("upstream request failed", "session_id", session.ID, "error", err)
		reportError(fmt.Sprintf("upstream request failed: %v", err))
		return
	}
	defer func() { _ = respBody.Close() }()

	if statusCode < 400 {
		svc.streamUpstreamResponse(ctx, session.ID, respBody, updateChan)
		return
	}

	// Best effort: read at most 2 MiB of the error body. A read error just
	// yields a shorter (possibly empty) body in the event.
	errBody, _ := io.ReadAll(io.LimitReader(respBody, 2<<20))
	svc.logger.Error("upstream error response", "session_id", session.ID, "status_code", statusCode, "body", string(errBody))
	updateChan <- map[string]interface{}{
		"type":        "error",
		"status_code": statusCode,
		"error":       string(errBody),
	}
}
// streamUpstreamResponse relays an upstream SSE response to updateChan.
//
// The body is read line by line:
//   - Blank lines, ":" comment lines and non-"data:" fields are ignored.
//   - Each "data:" payload is decoded as a JSON object and sent to updateChan.
//   - Payloads that fail to decode (for example a "[DONE]" sentinel) are
//     logged at debug level and skipped.
//
// Relaying stops early when ctx is done, or when updateChan does not accept
// an event within 5 seconds. If reading the body fails — including a single
// line longer than the 512 KiB scanner limit — an {"type": "error"} event is
// sent. This lets the consumer tell a truncated stream from a completed one.
// body is not closed here; the caller owns it.
func (svc *LanguageServerService) streamUpstreamResponse(
	ctx context.Context,
	sessionID string,
	body io.ReadCloser,
	updateChan chan<- interface{},
) {
	const sendTimeout = 5 * time.Second

	scanner := bufio.NewScanner(body)
	// Start with a 64 KiB buffer and let it grow to 512 KiB per line.
	scanner.Buffer(make([]byte, 64*1024), 512*1024)

	// One timer is reused for every send instead of allocating a new
	// time.After timer per event.
	sendTimer := time.NewTimer(sendTimeout)
	defer sendTimer.Stop()

	// send delivers event to updateChan and reports whether streaming should
	// continue. It gives up when ctx is done or the consumer does not accept
	// the event within sendTimeout.
	send := func(event interface{}) bool {
		// Stop and drain before Reset so a stale expiry cannot fire
		// immediately (needed before Go 1.23, harmless afterwards).
		if !sendTimer.Stop() {
			select {
			case <-sendTimer.C:
			default:
			}
		}
		sendTimer.Reset(sendTimeout)

		select {
		case updateChan <- event:
			return true
		case <-ctx.Done():
			return false
		case <-sendTimer.C:
			svc.logger.Warn("channel send timeout",
				"session_id", sessionID,
			)
			return false
		}
	}

	for scanner.Scan() {
		if ctx.Err() != nil {
			svc.logger.Info("streaming cancelled", "session_id", sessionID)
			return
		}

		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, ":") {
			continue
		}

		eventData, isData := strings.CutPrefix(line, "data:")
		if !isData {
			continue
		}
		eventData = strings.TrimSpace(eventData)

		var event map[string]interface{}
		if err := json.Unmarshal([]byte(eventData), &event); err != nil {
			svc.logger.Debug("failed to parse event",
				"session_id", sessionID,
				"error", err,
				"data", eventData,
			)
			continue
		}

		if !send(event) {
			return
		}
	}

	if err := scanner.Err(); err != nil {
		svc.logger.Error("scanning upstream response failed",
			"session_id", sessionID,
			"error", err,
		)
		// Without this event the channel would simply close and the
		// truncated stream would look like a normal end of response. Skip it
		// when ctx is already done, since the caller has abandoned the stream.
		if ctx.Err() == nil {
			send(map[string]interface{}{
				"type":  "error",
				"error": fmt.Sprintf("reading upstream response failed: %v", err),
			})
		}
	}
}
// getCurrentTimeMS returns the current wall-clock time in Unix milliseconds.
func getCurrentTimeMS() int64 {
	now := time.Now()
	return now.UnixMilli()
}