263 lines
7.1 KiB
Go
263 lines
7.1 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type fakeLSBootstrapAccountReader struct {
|
|
mu sync.Mutex
|
|
accounts []Account
|
|
err error
|
|
platforms []string
|
|
}
|
|
|
|
func (f *fakeLSBootstrapAccountReader) ListByPlatform(_ context.Context, platform string) ([]Account, error) {
|
|
f.mu.Lock()
|
|
f.platforms = append(f.platforms, platform)
|
|
accounts := append([]Account(nil), f.accounts...)
|
|
err := f.err
|
|
f.mu.Unlock()
|
|
return accounts, err
|
|
}
|
|
|
|
type fakeLSPoolBackend struct {
|
|
mu sync.Mutex
|
|
tokenCalls map[string]fakeLSPoolTokenCall
|
|
creditCalls map[string]fakeLSPoolCreditCall
|
|
getCalls []fakeLSPoolGetCall
|
|
getErrs map[string]error
|
|
}
|
|
|
|
// fakeLSPoolTokenCall captures the arguments of one SetAccountToken call.
type fakeLSPoolTokenCall struct {
	AccessToken, RefreshToken string
	ExpiresAt                 time.Time
}
|
|
|
|
// fakeLSPoolCreditCall captures the arguments of one SetAccountModelCredits
// call. Either pointer field may be nil when the caller passed nil.
type fakeLSPoolCreditCall struct {
	UseAICredits                          bool
	AvailableCredits, MinimumCreditAmount *int32
}
|
|
|
|
// fakeLSPoolGetCall captures the arguments of one GetOrCreate call. ProxyURL
// is the first optional proxy argument, or "" when none was passed.
type fakeLSPoolGetCall struct {
	AccountID, RoutingKey, ProxyURL string
}
|
|
|
|
func newFakeLSPoolBackend() *fakeLSPoolBackend {
|
|
return &fakeLSPoolBackend{
|
|
tokenCalls: make(map[string]fakeLSPoolTokenCall),
|
|
creditCalls: make(map[string]fakeLSPoolCreditCall),
|
|
getErrs: make(map[string]error),
|
|
}
|
|
}
|
|
|
|
func (f *fakeLSPoolBackend) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*lspool.Instance, error) {
|
|
rawProxy := ""
|
|
if len(proxyURL) > 0 {
|
|
rawProxy = proxyURL[0]
|
|
}
|
|
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
f.getCalls = append(f.getCalls, fakeLSPoolGetCall{
|
|
AccountID: accountID,
|
|
RoutingKey: routingKey,
|
|
ProxyURL: rawProxy,
|
|
})
|
|
if err := f.getErrs[accountID]; err != nil {
|
|
return nil, err
|
|
}
|
|
return &lspool.Instance{AccountID: accountID}, nil
|
|
}
|
|
|
|
func (f *fakeLSPoolBackend) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
f.tokenCalls[accountID] = fakeLSPoolTokenCall{
|
|
AccessToken: accessToken,
|
|
RefreshToken: refreshToken,
|
|
ExpiresAt: expiresAt,
|
|
}
|
|
}
|
|
|
|
func (f *fakeLSPoolBackend) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
f.creditCalls[accountID] = fakeLSPoolCreditCall{
|
|
UseAICredits: useAICredits,
|
|
AvailableCredits: copyInt32Ptr(availableCredits),
|
|
MinimumCreditAmount: copyInt32Ptr(minimumCreditAmountForUsage),
|
|
}
|
|
}
|
|
|
|
// Stats reports no statistics for the fake: it always returns a nil map.
func (f *fakeLSPoolBackend) Stats() map[string]any {
	return nil
}
|
|
|
|
// Close is a no-op for the fake.
func (f *fakeLSPoolBackend) Close() {
	// Nothing to release: the fake only holds in-memory records.
}
|
|
|
|
// copyInt32Ptr returns a pointer to a fresh copy of *v, or nil when v is nil.
// The result never aliases the caller's variable.
func copyInt32Ptr(v *int32) *int32 {
	if v != nil {
		dup := new(int32)
		*dup = *v
		return dup
	}
	return nil
}
|
|
|
|
func TestLSPoolBootstrapServiceBootstrapEligibleAccounts(t *testing.T) {
|
|
expiresAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
|
|
expiredAt := time.Now().Add(-2 * time.Hour)
|
|
reader := &fakeLSBootstrapAccountReader{
|
|
accounts: []Account{
|
|
{
|
|
ID: 101,
|
|
Platform: PlatformAntigravity,
|
|
Type: AccountTypeOAuth,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Credentials: map[string]any{
|
|
"access_token": "token-101",
|
|
"refresh_token": "refresh-101",
|
|
"expires_at": expiresAt.Format(time.RFC3339),
|
|
"project_id": "proj-101",
|
|
},
|
|
Extra: map[string]any{
|
|
"allow_overages": true,
|
|
"ai_credits": []any{
|
|
map[string]any{
|
|
"credit_type": "GOOGLE_ONE_AI",
|
|
"amount": 120,
|
|
"minimum_balance": 55,
|
|
},
|
|
},
|
|
},
|
|
Proxy: &Proxy{
|
|
Protocol: "socks5h",
|
|
Host: "127.0.0.1",
|
|
Port: 1080,
|
|
Username: "alice",
|
|
Password: "secret",
|
|
},
|
|
},
|
|
{
|
|
ID: 102,
|
|
Platform: PlatformAntigravity,
|
|
Type: AccountTypeOAuth,
|
|
Status: StatusActive,
|
|
Schedulable: false,
|
|
Credentials: map[string]any{"access_token": "token-102", "project_id": "proj-102"},
|
|
},
|
|
{
|
|
ID: 103,
|
|
Platform: PlatformAntigravity,
|
|
Type: AccountTypeOAuth,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Credentials: map[string]any{"access_token": "token-103"},
|
|
},
|
|
{
|
|
ID: 104,
|
|
Platform: PlatformAntigravity,
|
|
Type: AccountTypeOAuth,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
AutoPauseOnExpired: true,
|
|
ExpiresAt: &expiredAt,
|
|
Credentials: map[string]any{"access_token": "token-104", "project_id": "proj-104"},
|
|
},
|
|
{
|
|
ID: 106,
|
|
Platform: PlatformAntigravity,
|
|
Type: AccountTypeUpstream,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Credentials: map[string]any{"access_token": "token-106", "project_id": "proj-106"},
|
|
},
|
|
{
|
|
ID: 105,
|
|
Platform: PlatformOpenAI,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Credentials: map[string]any{"access_token": "token-105"},
|
|
},
|
|
},
|
|
}
|
|
backend := newFakeLSPoolBackend()
|
|
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{
|
|
Gateway: config.GatewayConfig{
|
|
AntigravityLSWorker: config.GatewayAntigravityLSWorkerConfig{MaxActive: 3},
|
|
},
|
|
})
|
|
|
|
svc.bootstrap(context.Background())
|
|
|
|
require.Equal(t, []string{PlatformAntigravity}, reader.platforms)
|
|
|
|
require.Len(t, backend.getCalls, 1)
|
|
require.Equal(t, fakeLSPoolGetCall{
|
|
AccountID: "101",
|
|
RoutingKey: "",
|
|
ProxyURL: "socks5h://alice:secret@127.0.0.1:1080",
|
|
}, backend.getCalls[0])
|
|
|
|
tokenCall, ok := backend.tokenCalls["101"]
|
|
require.True(t, ok)
|
|
require.Equal(t, "token-101", tokenCall.AccessToken)
|
|
require.Equal(t, "refresh-101", tokenCall.RefreshToken)
|
|
require.Equal(t, expiresAt, tokenCall.ExpiresAt)
|
|
|
|
creditCall, ok := backend.creditCalls["101"]
|
|
require.True(t, ok)
|
|
require.True(t, creditCall.UseAICredits)
|
|
require.NotNil(t, creditCall.AvailableCredits)
|
|
require.Equal(t, int32(120), *creditCall.AvailableCredits)
|
|
require.NotNil(t, creditCall.MinimumCreditAmount)
|
|
require.Equal(t, int32(55), *creditCall.MinimumCreditAmount)
|
|
|
|
require.NotContains(t, backend.tokenCalls, "102")
|
|
require.NotContains(t, backend.tokenCalls, "103")
|
|
require.NotContains(t, backend.tokenCalls, "104")
|
|
require.NotContains(t, backend.tokenCalls, "106")
|
|
}
|
|
|
|
func TestLSPoolBootstrapServiceBootstrapContinuesOnWorkerFailure(t *testing.T) {
|
|
reader := &fakeLSBootstrapAccountReader{
|
|
accounts: []Account{
|
|
{
|
|
ID: 201,
|
|
Platform: PlatformAntigravity,
|
|
Type: AccountTypeOAuth,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Credentials: map[string]any{"access_token": "token-201", "project_id": "proj-201"},
|
|
},
|
|
{
|
|
ID: 202,
|
|
Platform: PlatformAntigravity,
|
|
Type: AccountTypeOAuth,
|
|
Status: StatusActive,
|
|
Schedulable: true,
|
|
Credentials: map[string]any{"access_token": "token-202", "project_id": "proj-202"},
|
|
},
|
|
},
|
|
}
|
|
backend := newFakeLSPoolBackend()
|
|
backend.getErrs["201"] = errors.New("create failed")
|
|
|
|
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{})
|
|
svc.bootstrap(context.Background())
|
|
|
|
require.Len(t, backend.getCalls, 2)
|
|
require.Contains(t, backend.tokenCalls, "201")
|
|
require.Contains(t, backend.tokenCalls, "202")
|
|
}
|